diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml index 2470d98..93e5a3d 100644 --- a/.github/codeql/codeql-config.yml +++ b/.github/codeql/codeql-config.yml @@ -7,7 +7,16 @@ queries: - uses: security-and-quality - uses: ./.github/codeql/queries +paths-ignore: + - "**/tests/**" + - "**/test/**" + - "**/*_test.rs" + - "**/*/tests.rs" query-filters: + - exclude: + tags: + - test + - exclude: id: - rust/unwrap-on-option diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1b6e455 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,39 @@ +name: Build + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Build + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry & build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Build Release + run: cargo build --release --verbose \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5cf034a..5e21a85 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,11 +4,15 @@ on: push: tags: - '[0-9]+.[0-9]+.[0-9]+' - - '[0-9]+.[0-9]+.[0-9]+-*' workflow_dispatch: + inputs: + tag: + description: 'Release tag (example: 3.3.15)' + required: true + type: string concurrency: - group: release-${{ github.ref }} + group: release-${{ github.ref_name }}-${{ github.event.inputs.tag || 'auto' }} cancel-in-progress: true permissions: @@ -16,201 +20,386 @@ permissions: env: CARGO_TERM_COLOR: always - RUST_BACKTRACE: "1" BINARY_NAME: telemt jobs: prepare: + name: Prepare runs-on: ubuntu-latest + outputs: - version: ${{ steps.meta.outputs.version }} - prerelease: ${{ steps.meta.outputs.prerelease }} - release_enabled: ${{ steps.meta.outputs.release_enabled }} + version: ${{ steps.vars.outputs.version }} + prerelease: ${{ steps.vars.outputs.prerelease }} + steps: - - id: meta + - name: Resolve version + id: vars + shell: bash run: | set -euo pipefail - if [[ "${GITHUB_REF}" == refs/tags/* ]]; then - VERSION="${GITHUB_REF#refs/tags/}" - RELEASE_ENABLED=true + if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then + VERSION="${{ github.event.inputs.tag }}" else - VERSION="manual-${GITHUB_SHA::7}" - RELEASE_ENABLED=false + VERSION="${GITHUB_REF#refs/tags/}" fi - if [[ "$VERSION" == *"-alpha"* || "$VERSION" == *"-beta"* || "$VERSION" == *"-rc"* ]]; then + VERSION="${VERSION#refs/tags/}" + + if [ -z "${VERSION}" ]; then + echo "Release version is empty" >&2 + exit 1 + fi + + if [[ "${VERSION}" == *-* ]]; then PRERELEASE=true else PRERELEASE=false fi - echo "version=$VERSION" >> "$GITHUB_OUTPUT" - echo "prerelease=$PRERELEASE" >> "$GITHUB_OUTPUT" - echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT" + echo "version=${VERSION}" >> "${GITHUB_OUTPUT}" + echo "prerelease=${PRERELEASE}" >> "${GITHUB_OUTPUT}" - checks: + # ========================== + # GNU / glibc + # ========================== + build-gnu: + name: GNU ${{ matrix.asset }} runs-on: ubuntu-latest + needs: prepare + container: - image: debian:trixie - steps: - - run: | - apt-get update - apt-get install -y 
build-essential clang llvm pkg-config curl git - - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - uses: actions/cache@v4 - with: - path: | - /github/home/.cargo/registry - /github/home/.cargo/git - target - key: checks-${{ hashFiles('**/Cargo.lock') }} - - - run: cargo fetch --locked - - run: cargo fmt --all -- --check - - run: cargo clippy - - run: cargo test - - build-binaries: - needs: [prepare, checks] - runs-on: ubuntu-latest - container: - image: debian:trixie + image: rust:slim-bookworm strategy: fail-fast: false matrix: include: - - rust_target: x86_64-unknown-linux-gnu - zig_target: x86_64-unknown-linux-gnu.2.28 - asset_name: telemt-x86_64-linux-gnu - - rust_target: aarch64-unknown-linux-gnu - zig_target: aarch64-unknown-linux-gnu.2.28 - asset_name: telemt-aarch64-linux-gnu - - rust_target: x86_64-unknown-linux-musl - zig_target: x86_64-unknown-linux-musl - asset_name: telemt-x86_64-linux-musl - - rust_target: aarch64-unknown-linux-musl - zig_target: aarch64-unknown-linux-musl - asset_name: telemt-aarch64-linux-musl + - target: x86_64-unknown-linux-gnu + asset: telemt-x86_64-linux-gnu + cpu: baseline + + - target: x86_64-unknown-linux-gnu + asset: telemt-x86_64-v3-linux-gnu + cpu: v3 + + - target: aarch64-unknown-linux-gnu + asset: telemt-aarch64-linux-gnu + cpu: generic steps: - - run: | - apt-get update - apt-get install -y clang llvm pkg-config curl git python3 python3-pip file tar xz-utils - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable + + - uses: dtolnay/rust-toolchain@v1 with: - targets: ${{ matrix.rust_target }} + toolchain: stable + targets: | + x86_64-unknown-linux-gnu + aarch64-unknown-linux-gnu + + - name: Install deps + run: | + apt-get update + apt-get install -y \ + build-essential \ + clang \ + lld \ + pkg-config \ + gcc-aarch64-linux-gnu \ + g++-aarch64-linux-gnu - uses: actions/cache@v4 with: path: | - /github/home/.cargo/registry - /github/home/.cargo/git + /usr/local/cargo/registry + /usr/local/cargo/git target - key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }} + key: gnu-${{ matrix.asset }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + gnu-${{ matrix.asset }}- + gnu- - - run: | - python3 -m pip install --user --break-system-packages cargo-zigbuild - echo "/github/home/.local/bin" >> "$GITHUB_PATH" + - name: Build + shell: bash + run: | + set -euo pipefail - - run: cargo fetch --locked + if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then + export CC=aarch64-linux-gnu-gcc + export CXX=aarch64-linux-gnu-g++ + export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc -C lto=fat -C panic=abort" + else + export CC=clang + export CXX=clang++ - - run: | - cargo zigbuild --release --locked --target "${{ matrix.zig_target }}" + if [ "${{ matrix.cpu }}" = "v3" ]; then + CPU_FLAGS="-C target-cpu=x86-64-v3" + else + CPU_FLAGS="-C target-cpu=x86-64" + fi - - run: | - BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}" - llvm-strip "$BIN" || true + export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld -C lto=fat -C panic=abort ${CPU_FLAGS}" + fi - - run: | - BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}" - OUT="$RUNNER_TEMP/${{ matrix.asset_name }}" - mkdir -p "$OUT" - install -m755 "$BIN" "$OUT/${BINARY_NAME}" + cargo build --release --target ${{ matrix.target }} -j "$(nproc)" - tar -C "$RUNNER_TEMP" -czf "${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}" - sha256sum "${{ matrix.asset_name }}.tar.gz" > "${{ 
matrix.asset_name }}.sha256" + - name: Package + shell: bash + run: | + set -euo pipefail + + mkdir -p dist + cp "target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}" dist/telemt + + cd dist + tar -czf "${{ matrix.asset }}.tar.gz" \ + --owner=0 --group=0 --numeric-owner \ + telemt + + sha256sum "${{ matrix.asset }}.tar.gz" > "${{ matrix.asset }}.tar.gz.sha256" - uses: actions/upload-artifact@v4 with: - name: ${{ matrix.asset_name }} - path: | - ${{ matrix.asset_name }}.tar.gz - ${{ matrix.asset_name }}.sha256 + name: ${{ matrix.asset }} + path: dist/* - docker-image: - name: Docker ${{ matrix.platform }} - needs: [prepare, build-binaries] + # ========================== + # MUSL + # ========================== + build-musl: + name: MUSL ${{ matrix.asset }} runs-on: ubuntu-latest + needs: prepare + + container: + image: rust:slim-bookworm strategy: + fail-fast: false matrix: include: - - platform: linux/amd64 - artifact: telemt-x86_64-linux-gnu - - platform: linux/arm64 - artifact: telemt-aarch64-linux-gnu + - target: x86_64-unknown-linux-musl + asset: telemt-x86_64-linux-musl + cpu: baseline + + - target: x86_64-unknown-linux-musl + asset: telemt-x86_64-v3-linux-musl + cpu: v3 + + - target: aarch64-unknown-linux-musl + asset: telemt-aarch64-linux-musl + cpu: generic steps: - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 + - name: Install deps + run: | + apt-get update + apt-get install -y \ + musl-tools \ + pkg-config \ + curl + + - uses: actions/cache@v4 + if: matrix.target == 'aarch64-unknown-linux-musl' with: - name: ${{ matrix.artifact }} - path: dist + path: ~/.musl-aarch64 + key: musl-toolchain-aarch64-v1 - - run: | - mkdir docker-build - tar -xzf dist/*.tar.gz -C docker-build --strip-components=1 + - name: Install aarch64 musl toolchain + if: matrix.target == 'aarch64-unknown-linux-musl' + shell: bash + run: | + set -euo pipefail - - uses: docker/setup-buildx-action@v3 + TOOLCHAIN_DIR="$HOME/.musl-aarch64" + ARCHIVE="aarch64-linux-musl-cross.tgz" + URL="https://github.com/telemt/telemt/releases/download/toolchains/${ARCHIVE}" - - name: Login - if: ${{ needs.prepare.outputs.release_enabled == 'true' }} - uses: docker/login-action@v3 + if [ -x "${TOOLCHAIN_DIR}/bin/aarch64-linux-musl-gcc" ]; then + echo "MUSL toolchain cached" + else + curl -fL \ + --retry 5 \ + --retry-delay 3 \ + --connect-timeout 10 \ + --max-time 120 \ + -o "${ARCHIVE}" "${URL}" + + mkdir -p "${TOOLCHAIN_DIR}" + tar -xzf "${ARCHIVE}" --strip-components=1 -C "${TOOLCHAIN_DIR}" + fi + + echo "${TOOLCHAIN_DIR}/bin" >> "${GITHUB_PATH}" + + - name: Add rust target + run: rustup target add ${{ matrix.target }} + + - uses: actions/cache@v4 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + path: | + /usr/local/cargo/registry + /usr/local/cargo/git + target + key: musl-${{ matrix.asset }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + musl-${{ matrix.asset }}- + musl- - - uses: docker/build-push-action@v6 + - name: Build + shell: bash + run: | + set -euo pipefail + + if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then + export CC=aarch64-linux-musl-gcc + export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc + export RUSTFLAGS="-C target-feature=+crt-static -C linker=aarch64-linux-musl-gcc -C lto=fat -C panic=abort" + else + export CC=musl-gcc + export CC_x86_64_unknown_linux_musl=musl-gcc + + if [ "${{ matrix.cpu }}" = "v3" ]; then + CPU_FLAGS="-C target-cpu=x86-64-v3" + else + CPU_FLAGS="-C target-cpu=x86-64" + fi + + 
export RUSTFLAGS="-C target-feature=+crt-static -C lto=fat -C panic=abort ${CPU_FLAGS}" + fi + + cargo build --release --target ${{ matrix.target }} -j "$(nproc)" + + - name: Package + shell: bash + run: | + set -euo pipefail + + mkdir -p dist + cp "target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}" dist/telemt + + cd dist + tar -czf "${{ matrix.asset }}.tar.gz" \ + --owner=0 --group=0 --numeric-owner \ + telemt + + sha256sum "${{ matrix.asset }}.tar.gz" > "${{ matrix.asset }}.tar.gz.sha256" + + - uses: actions/upload-artifact@v4 with: - context: ./docker-build - platforms: ${{ matrix.platform }} - push: ${{ needs.prepare.outputs.release_enabled == 'true' }} - tags: ghcr.io/${{ github.repository }}:${{ needs.prepare.outputs.version }} - cache-from: type=gha,scope=telemt-${{ matrix.platform }} - cache-to: type=gha,mode=max,scope=telemt-${{ matrix.platform }} - provenance: false - sbom: false + name: ${{ matrix.asset }} + path: dist/* + # ========================== + # Release + # ========================== release: - if: ${{ needs.prepare.outputs.release_enabled == 'true' }} - needs: [prepare, build-binaries] + name: Release runs-on: ubuntu-latest + needs: [prepare, build-gnu, build-musl] + permissions: contents: write steps: - uses: actions/download-artifact@v4 with: - path: release-artifacts - pattern: telemt-* + path: artifacts - - run: | - mkdir upload - find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \; + - name: Flatten artifacts + shell: bash + run: | + set -euo pipefail + mkdir -p dist + find artifacts -type f -exec cp {} dist/ \; - - uses: softprops/action-gh-release@v2 + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 with: - files: upload/* + tag_name: ${{ needs.prepare.outputs.version }} + target_commitish: ${{ github.sha }} + files: dist/* generate_release_notes: true prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }} + overwrite_files: true + + # ========================== + # Docker + # ========================== + docker: + name: Docker + runs-on: ubuntu-latest + needs: [prepare, release] + + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + + - uses: docker/setup-qemu-action@v3 + + - uses: docker/setup-buildx-action@v3 + + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Probe release assets + shell: bash + env: + VERSION: ${{ needs.prepare.outputs.version }} + run: | + set -euo pipefail + + for asset in \ + telemt-x86_64-linux-musl.tar.gz \ + telemt-x86_64-linux-musl.tar.gz.sha256 \ + telemt-aarch64-linux-musl.tar.gz \ + telemt-aarch64-linux-musl.tar.gz.sha256 + do + curl -fsIL \ + --retry 10 \ + --retry-delay 3 \ + "https://github.com/${GITHUB_REPOSITORY}/releases/download/${VERSION}/${asset}" \ + > /dev/null + done + + - name: Compute image tags + id: meta + shell: bash + env: + VERSION: ${{ needs.prepare.outputs.version }} + run: | + set -euo pipefail + + IMAGE="$(echo "ghcr.io/${GITHUB_REPOSITORY}" | tr '[:upper:]' '[:lower:]')" + TAGS="${IMAGE}:${VERSION}" + + if [[ "${VERSION}" != *-* ]]; then + TAGS="${TAGS}"$'\n'"${IMAGE}:latest" + fi + + { + echo "tags<> "${GITHUB_OUTPUT}" + + - name: Build & Push + uses: docker/build-push-action@v6 + with: + context: . 
+ push: true + pull: true + platforms: linux/amd64,linux/arm64 + tags: ${{ steps.meta.outputs.tags }} + build-args: | + TELEMT_REPOSITORY=${{ github.repository }} + TELEMT_VERSION=${{ needs.prepare.outputs.version }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index 799f2ce..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Rust - -on: - push: - branches: [ "*" ] - pull_request: - branches: [ "*" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - name: Build - runs-on: ubuntu-latest - - permissions: - contents: read - actions: write - checks: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - name: Cache cargo registry & build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo- - - - name: Build Release - run: cargo build --release --verbose - - - name: Run tests - run: cargo test --verbose - - - name: Stress quota-lock suites (PR only) - if: github.event_name == 'pull_request' - env: - RUST_TEST_THREADS: 16 - run: | - set -euo pipefail - for i in $(seq 1 12); do - echo "[quota-lock-stress] iteration ${i}/12" - cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 - cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 - done - -# clippy dont fail on warnings because of active development of telemt -# and many warnings - - name: Run clippy - run: cargo clippy -- --cap-lints warn - - - name: Check for unused dependencies - run: cargo udeps || true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..8e43dd7 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,139 @@ +name: Check + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true + +jobs: +# ========================== +# Formatting +# ========================== + fmt: + name: Fmt + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - run: cargo fmt -- --check + +# ========================== +# Tests +# ========================== + test: + name: Test + runs-on: ubuntu-latest + + permissions: + contents: read + actions: write + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-nextest-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-nextest- + ${{ runner.os }}-cargo- + + - name: Install cargo-nextest + run: cargo install --locked cargo-nextest || true + + - name: Run tests with nextest + run: cargo nextest run -j "$(nproc)" + +# ========================== +# Clippy +# ========================== + clippy: + name: Clippy + runs-on: ubuntu-latest + + permissions: + contents: read + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: 
actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-clippy-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-clippy- + ${{ runner.os }}-cargo- + + - name: Run clippy + run: cargo clippy -j "$(nproc)" -- --cap-lints warn + +# ========================== +# Udeps +# ========================== + udeps: + name: Udeps + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rust-src + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-udeps-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo-udeps- + ${{ runner.os }}-cargo- + + - name: Install cargo-udeps + run: cargo install --locked cargo-udeps || true + + - name: Run udeps + run: cargo udeps -j "$(nproc)" || true diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 84c5f77..14f9318 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,8 +1,8 @@ # Code of Conduct -## 1. Purpose +## Purpose -Telemt exists to solve technical problems. +**Telemt exists to solve technical problems.** Telemt is open to contributors who want to learn, improve and build meaningful systems together. @@ -18,27 +18,34 @@ Technology has consequences. Responsibility is inherent. --- -## 2. Principles +## Principles * **Technical over emotional** + Arguments are grounded in data, logs, reproducible cases, or clear reasoning. * **Clarity over noise** + Communication is structured, concise, and relevant. * **Openness with standards** + Participation is open. The work remains disciplined. * **Independence of judgment** + Claims are evaluated on technical merit, not affiliation or posture. * **Responsibility over capability** + Capability does not justify careless use. * **Cooperation over friction** + Progress depends on coordination, mutual support, and honest review. * **Good intent, rigorous method** + Assume good intent, but require rigor. > **Aussagen gelten nach ihrer Begründung.** @@ -47,7 +54,7 @@ Technology has consequences. Responsibility is inherent. --- -## 3. Expected Behavior +## Expected Behavior Participants are expected to: @@ -69,7 +76,7 @@ New contributors are welcome. They are expected to grow into these standards. Ex --- -## 4. Unacceptable Behavior +## Unacceptable Behavior The following is not allowed: @@ -89,7 +96,7 @@ Such discussions may be closed, removed, or redirected. --- -## 5. Security and Misuse +## Security and Misuse Telemt is intended for responsible use. @@ -109,15 +116,13 @@ Security is both technical and behavioral. Telemt is open to contributors of different backgrounds, experience levels, and working styles. -Standards are public, legible, and applied to the work itself. - -Questions are welcome. Careful disagreement is welcome. Honest correction is welcome. - -Gatekeeping by obscurity, status signaling, or hostility is not. +- Standards are public, legible, and applied to the work itself. +- Questions are welcome. Careful disagreement is welcome. Honest correction is welcome. +- Gatekeeping by obscurity, status signaling, or hostility is not. --- -## 7. Scope +## Scope This Code of Conduct applies to all official spaces: @@ -127,16 +132,19 @@ This Code of Conduct applies to all official spaces: --- -## 8. 
Maintainer Stewardship +## Maintainer Stewardship Maintainers are responsible for final decisions in matters of conduct, scope, and direction. -This responsibility is stewardship: preserving continuity, protecting signal, maintaining standards, and keeping Telemt workable for others. +This responsibility is stewardship: +- preserving continuity, +- protecting signal, +- maintaining standards, +- keeping Telemt workable for others. Judgment should be exercised with restraint, consistency, and institutional responsibility. - -Not every decision requires extended debate. -Not every intervention requires public explanation. +- Not every decision requires extended debate. +- Not every intervention requires public explanation. All decisions are expected to serve the durability, clarity, and integrity of Telemt. @@ -146,7 +154,7 @@ All decisions are expected to serve the durability, clarity, and integrity of Te --- -## 9. Enforcement +## Enforcement Maintainers may act to preserve the integrity of Telemt, including by: @@ -156,44 +164,40 @@ Maintainers may act to preserve the integrity of Telemt, including by: * Restricting or banning participants Actions are taken to maintain function, continuity, and signal quality. - -Where possible, correction is preferred to exclusion. - -Where necessary, exclusion is preferred to decay. +- Where possible, correction is preferred to exclusion. +- Where necessary, exclusion is preferred to decay. --- -## 10. Final +## Final Telemt is built on discipline, structure, and shared intent. +- Signal over noise. +- Facts over opinion. +- Systems over rhetoric. -Signal over noise. -Facts over opinion. -Systems over rhetoric. +- Work is collective. +- Outcomes are shared. +- Responsibility is distributed. -Work is collective. -Outcomes are shared. -Responsibility is distributed. - -Precision is learned. -Rigor is expected. -Help is part of the work. +- Precision is learned. +- Rigor is expected. +- Help is part of the work. > **Ordnung ist Voraussetzung der Freiheit.** -If you contribute — contribute with care. -If you speak — speak with substance. -If you engage — engage constructively. +- If you contribute — contribute with care. +- If you speak — speak with substance. +- If you engage — engage constructively. --- -## 11. After All +## After All Systems outlive intentions. - -What is built will be used. -What is released will propagate. -What is maintained will define the future state. +- What is built will be used. +- What is released will propagate. +- What is maintained will define the future state. There is no neutral infrastructure, only infrastructure shaped well or poorly. @@ -201,8 +205,8 @@ There is no neutral infrastructure, only infrastructure shaped well or poorly. > Every system carries responsibility. -Stability requires discipline. -Freedom requires structure. -Trust requires honesty. +- Stability requires discipline. +- Freedom requires structure. +- Trust requires honesty. -In the end, the system reflects its contributors. +In the end: the system reflects its contributors. 
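The version resolution in release.yml's `prepare` job above can be dry-run outside CI. A minimal bash sketch of the same rules (strip `refs/tags/`, treat any `-` suffix as a prerelease); `resolve_version` is a hypothetical helper, not code from this patch:

```bash
#!/usr/bin/env bash
# Local mirror of the prepare job's version/prerelease resolution
set -euo pipefail

resolve_version() {
  local ref="$1" version
  version="${ref#refs/tags/}"                 # strip the tag prefix, if present
  [ -n "${version}" ] || { echo "release version is empty" >&2; return 1; }
  if [[ "${version}" == *-* ]]; then          # any -suffix marks a prerelease
    echo "${version} prerelease=true"
  else
    echo "${version} prerelease=false"
  fi
}

resolve_version "refs/tags/3.3.32"     # -> 3.3.32 prerelease=false
resolve_version "refs/tags/3.4.0-rc1"  # -> 3.4.0-rc1 prerelease=true
```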
diff --git a/Cargo.lock b/Cargo.lock index ef6a7c3..13c2977 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,9 +1454,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", @@ -1486,7 +1486,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -1495,9 +1495,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -1659,9 +1681,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "crossbeam-channel", "crossbeam-epoch", @@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.29" +version = "3.3.32" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index a5f1fec..f1b6af2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] name = "telemt" -version = "3.3.29" +version = "3.3.32" edition = "2024" +[features] +redteam_offline_expected_fail = [] + [dependencies] # C libc = "0.2" @@ -93,4 +96,6 @@ name = "crypto_bench" harness = false [profile.release] -lto = "thin" +lto = "fat" +codegen-units = 1 + diff --git a/Dockerfile b/Dockerfile index 15a4900..d138ce9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,44 +1,98 @@ -# ========================== -# Stage 1: Build -# ========================== -FROM rust:1.88-slim-bookworm AS builder +# syntax=docker/dockerfile:1 -RUN apt-get update && apt-get install -y --no-install-recommends \ - pkg-config \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -COPY Cargo.toml Cargo.lock* ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ - cargo build --release 2>/dev/null || true && \ - rm -rf src - -COPY . . 
-RUN cargo build --release && strip target/release/telemt +ARG TELEMT_REPOSITORY=telemt/telemt +ARG TELEMT_VERSION=latest # ========================== -# Stage 2: Runtime +# Minimal Image # ========================== -FROM debian:bookworm-slim +FROM debian:12-slim AS minimal -RUN apt-get update && apt-get install -y --no-install-recommends \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* +ARG TARGETARCH +ARG TELEMT_REPOSITORY +ARG TELEMT_VERSION -RUN useradd -r -s /usr/sbin/nologin telemt +RUN set -eux; \ + apt-get update; \ + apt-get install -y --no-install-recommends \ + binutils \ + ca-certificates \ + curl \ + tar; \ + rm -rf /var/lib/apt/lists/* + +RUN set -eux; \ + case "${TARGETARCH}" in \ + amd64) ASSET="telemt-x86_64-linux-musl.tar.gz" ;; \ + arm64) ASSET="telemt-aarch64-linux-musl.tar.gz" ;; \ + *) echo "Unsupported TARGETARCH: ${TARGETARCH}" >&2; exit 1 ;; \ + esac; \ + VERSION="${TELEMT_VERSION#refs/tags/}"; \ + if [ -z "${VERSION}" ] || [ "${VERSION}" = "latest" ]; then \ + BASE_URL="https://github.com/${TELEMT_REPOSITORY}/releases/latest/download"; \ + else \ + BASE_URL="https://github.com/${TELEMT_REPOSITORY}/releases/download/${VERSION}"; \ + fi; \ + curl -fL \ + --retry 5 \ + --retry-delay 3 \ + --connect-timeout 10 \ + --max-time 120 \ + -o "/tmp/${ASSET}" \ + "${BASE_URL}/${ASSET}"; \ + curl -fL \ + --retry 5 \ + --retry-delay 3 \ + --connect-timeout 10 \ + --max-time 120 \ + -o "/tmp/${ASSET}.sha256" \ + "${BASE_URL}/${ASSET}.sha256"; \ + cd /tmp; \ + sha256sum -c "${ASSET}.sha256"; \ + tar -xzf "${ASSET}" -C /tmp; \ + test -f /tmp/telemt; \ + install -m 0755 /tmp/telemt /telemt; \ + strip --strip-unneeded /telemt || true; \ + rm -f "/tmp/${ASSET}" "/tmp/${ASSET}.sha256" /tmp/telemt + +# ========================== +# Debug Image +# ========================== +FROM debian:12-slim AS debug + +RUN set -eux; \ + apt-get update; \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + tzdata \ + curl \ + iproute2 \ + busybox; \ + rm -rf /var/lib/apt/lists/* WORKDIR /app -COPY --from=builder /build/target/release/telemt /app/telemt +COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -RUN chown -R telemt:telemt /app -USER telemt - -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 + +ENTRYPOINT ["/app/telemt"] +CMD ["config.toml"] + +# ========================== +# Production Distroless on MUSL +# ========================== +FROM gcr.io/distroless/static-debian12 AS prod + +WORKDIR /app + +COPY --from=minimal /telemt /app/telemt +COPY config.toml /app/config.toml + +USER nonroot:nonroot + +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] CMD ["config.toml"] diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 3eee3a7..eda2435 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -20,7 +20,7 @@ This document lists all configuration keys accepted by `config.toml`. | Parameter | Type | Default | Constraints / validation | Description | |---|---|---|---|---| | data_path | `String \| null` | `null` | — | Optional runtime data directory path. | -| prefer_ipv6 | `bool` | `false` | — | Prefer IPv6 where applicable in runtime logic. | +| prefer_ipv6 | `bool` | `false` | Deprecated. Use `network.prefer`. | Deprecated legacy IPv6 preference flag migrated to `network.prefer`. | | fast_mode | `bool` | `true` | — | Enables fast-path optimizations for traffic processing. | | use_middle_proxy | `bool` | `true` | none | Enables ME transport mode; if `false`, runtime falls back to direct DC routing. 
| | proxy_secret_path | `String \| null` | `"proxy-secret"` | Path may be `null`. | Path to Telegram infrastructure proxy-secret file used by ME handshake logic. | @@ -44,11 +44,14 @@ This document lists all configuration keys accepted by `config.toml`. | me_writer_cmd_channel_capacity | `usize` | `4096` | Must be `> 0`. | Capacity of per-writer command channel. | | me_route_channel_capacity | `usize` | `768` | Must be `> 0`. | Capacity of per-connection ME response route channel. | | me_c2me_channel_capacity | `usize` | `1024` | Must be `> 0`. | Capacity of per-client command queue (client reader -> ME sender). | +| me_c2me_send_timeout_ms | `u64` | `4000` | `0..=60000`. | Maximum wait for enqueueing client->ME commands when the per-client queue is full (`0` keeps legacy unbounded wait). | | me_reader_route_data_wait_ms | `u64` | `2` | `0..=20`. | Bounded wait for routing ME DATA to per-connection queue (`0` = no wait). | | me_d2c_flush_batch_max_frames | `usize` | `32` | `1..=512`. | Max ME->client frames coalesced before flush. | | me_d2c_flush_batch_max_bytes | `usize` | `131072` | `4096..=2_097_152`. | Max ME->client payload bytes coalesced before flush. | | me_d2c_flush_batch_max_delay_us | `u64` | `500` | `0..=5000`. | Max microsecond wait for coalescing more ME->client frames (`0` disables timed coalescing). | | me_d2c_ack_flush_immediate | `bool` | `true` | — | Flushes client writer immediately after quick-ack write. | +| me_quota_soft_overshoot_bytes | `u64` | `65536` | `0..=16_777_216`. | Extra per-route quota allowance (bytes) tolerated before writer-side quota enforcement drops route data. | +| me_d2c_frame_buf_shrink_threshold_bytes | `usize` | `262144` | `4096..=16_777_216`. | Threshold for shrinking oversized ME->client frame-aggregation buffers after flush. | | direct_relay_copy_buf_c2s_bytes | `usize` | `65536` | `4096..=1_048_576`. | Copy buffer size for client->DC direction in direct relay. | | direct_relay_copy_buf_s2c_bytes | `usize` | `262144` | `8192..=2_097_152`. | Copy buffer size for DC->client direction in direct relay. | | crypto_pending_buffer | `usize` | `262144` | — | Max pending ciphertext buffer per client writer (bytes). | @@ -105,6 +108,8 @@ This document lists all configuration keys accepted by `config.toml`. | me_warn_rate_limit_ms | `u64` | `5000` | Must be `> 0`. | Cooldown for repetitive ME warning logs (ms). | | me_route_no_writer_mode | `"async_recovery_failfast" \| "inline_recovery_legacy" \| "hybrid_async_persistent"` | `"hybrid_async_persistent"` | — | Route behavior when no writer is immediately available. | | me_route_no_writer_wait_ms | `u64` | `250` | `10..=5000`. | Max wait in async-recovery failfast mode (ms). | +| me_route_hybrid_max_wait_ms | `u64` | `3000` | `50..=60000`. | Maximum cumulative wait in hybrid no-writer mode before failfast fallback (ms). | +| me_route_blocking_send_timeout_ms | `u64` | `250` | `0..=5000`. | Maximum wait for blocking route-channel send fallback (`0` keeps legacy unbounded wait). | | me_route_inline_recovery_attempts | `u32` | `3` | Must be `> 0`. | Inline recovery attempts in legacy mode. | | me_route_inline_recovery_wait_ms | `u64` | `3000` | `10..=30000`. | Max inline recovery wait in legacy mode (ms). | | fast_mode_min_tls_record | `usize` | `0` | — | Minimum TLS record size when fast-mode coalescing is enabled (`0` disables). | @@ -124,6 +129,7 @@ This document lists all configuration keys accepted by `config.toml`. 
| me_secret_atomic_snapshot | `bool` | `true` | — | Keeps selector and secret bytes from the same snapshot atomically. | | proxy_secret_len_max | `usize` | `256` | Must be within `[32, 4096]`. | Upper length limit for accepted proxy-secret bytes. | | me_pool_drain_ttl_secs | `u64` | `90` | none | Time window where stale writers remain fallback-eligible after map change. | +| me_instadrain | `bool` | `false` | — | Forces draining stale writers to be removed on the next cleanup tick, bypassing TTL/deadline waiting. | | me_pool_drain_threshold | `u64` | `128` | — | Max draining stale writers before batch force-close (`0` disables threshold cleanup). | | me_pool_drain_soft_evict_enabled | `bool` | `true` | — | Enables gradual soft-eviction of stale writers during drain/reinit instead of immediate hard close. | | me_pool_drain_soft_evict_grace_secs | `u64` | `30` | `0..=3600`. | Grace period before stale writers become soft-evict candidates. | @@ -198,10 +204,14 @@ This document lists all configuration keys accepted by `config.toml`. | listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. | | proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. | | proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). | +| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). | | metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). | | metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. | | metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. | | max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). | +| accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). | + +Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted. ## [server.api] @@ -226,7 +236,7 @@ This document lists all configuration keys accepted by `config.toml`. |---|---|---|---|---| | ip | `IpAddr` | — | — | Listener bind IP. | | announce | `String \| null` | — | — | Public IP/domain announced in proxy links (priority over `announce_ip`). | -| announce_ip | `IpAddr \| null` | — | — | Deprecated legacy announce IP (migrated to `announce` if needed). | +| announce_ip | `IpAddr \| null` | — | Deprecated. Use `announce`. | Deprecated legacy announce IP (migrated to `announce` if needed). | | proxy_protocol | `bool \| null` | `null` | — | Per-listener override for PROXY protocol enable flag. | | reuse_allow | `bool` | `false` | — | Enables `SO_REUSEPORT` for multi-instance bind sharing. | @@ -235,6 +245,10 @@ This document lists all configuration keys accepted by `config.toml`. 
| Parameter | Type | Default | Constraints / validation | Description | |---|---|---|---|---| | client_handshake | `u64` | `30` | — | Client handshake timeout. | +| relay_idle_policy_v2_enabled | `bool` | `true` | — | Enables soft/hard middle-relay client idle policy. | +| relay_client_idle_soft_secs | `u64` | `120` | Must be `> 0`; must be `<= relay_client_idle_hard_secs`. | Soft idle threshold for middle-relay client uplink inactivity (seconds). | +| relay_client_idle_hard_secs | `u64` | `360` | Must be `> 0`; must be `>= relay_client_idle_soft_secs`. | Hard idle threshold for middle-relay client uplink inactivity (seconds). | +| relay_idle_grace_after_downstream_activity_secs | `u64` | `30` | Must be `<= relay_client_idle_hard_secs`. | Extra hard-idle grace after recent downstream activity (seconds). | | tg_connect | `u64` | `10` | — | Upstream Telegram connect timeout. | | client_keepalive | `u64` | `15` | — | Client keepalive timeout. | | client_ack | `u64` | `90` | — | Client ACK timeout. | @@ -247,6 +261,9 @@ This document lists all configuration keys accepted by `config.toml`. |---|---|---|---|---| | tls_domain | `String` | `"petrovich.ru"` | — | Primary TLS domain used in fake TLS handshake profile. | | tls_domains | `String[]` | `[]` | — | Additional TLS domains for generating multiple links. | +| unknown_sni_action | `"drop" \| "mask"` | `"drop"` | — | Action for TLS ClientHello with unknown/non-configured SNI. | +| tls_fetch_scope | `String` | `""` | Value is trimmed during load; empty keeps default upstream routing behavior. | Upstream scope tag used for TLS-front metadata fetches. | +| tls_fetch | `Table` | built-in defaults | See `[censorship.tls_fetch]` section below. | TLS-front metadata fetch strategy settings. | | mask | `bool` | `true` | — | Enables masking/fronting relay mode. | | mask_host | `String \| null` | `null` | — | Upstream mask host for TLS fronting relay. | | mask_port | `u16` | `443` | — | Upstream mask port for TLS fronting relay. | @@ -266,10 +283,24 @@ This document lists all configuration keys accepted by `config.toml`. | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. | +| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. | | mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | | mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | | mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. 
| Upper bound (ms) for masking outcome normalization target. | +## [censorship.tls_fetch] + +| Parameter | Type | Default | Constraints / validation | Description | +|---|---|---|---|---| +| profiles | `("modern_chrome_like" \| "modern_firefox_like" \| "compat_tls12" \| "legacy_minimal")[]` | `["modern_chrome_like", "modern_firefox_like", "compat_tls12", "legacy_minimal"]` | Empty list falls back to defaults; values are deduplicated preserving order. | Ordered ClientHello profile fallback chain for TLS-front metadata fetch. | +| strict_route | `bool` | `true` | — | Fails closed on upstream-route connect errors instead of falling back to direct TCP when route is configured. | +| attempt_timeout_ms | `u64` | `5000` | Must be `> 0`. | Timeout budget per one TLS-fetch profile attempt (ms). | +| total_budget_ms | `u64` | `15000` | Must be `> 0`. | Total wall-clock budget across all TLS-fetch attempts (ms). | +| grease_enabled | `bool` | `false` | — | Enables GREASE-style random values in selected ClientHello extensions for fetch traffic. | +| deterministic | `bool` | `false` | — | Enables deterministic ClientHello randomness for debugging/tests. | +| profile_cache_ttl_secs | `u64` | `600` | `0` disables cache. | TTL for winner-profile cache entries used by TLS fetch path. | + ### Shape-channel hardening notes (`[censorship]`) These parameters are designed to reduce one specific fingerprint source during masking: the exact number of bytes sent from proxy to `mask_host` for invalid or probing traffic. diff --git a/docs/FAQ.en.md b/docs/FAQ.en.md index 4af1c34..3d84348 100644 --- a/docs/FAQ.en.md +++ b/docs/FAQ.en.md @@ -3,7 +3,7 @@ 1. Go to @MTProxybot bot. 2. Enter the command `/newproxy` 3. Send the server IP and port. For example: 1.2.3.4:443 -4. Open the config `nano /etc/telemt.toml`. +4. Open the config `nano /etc/telemt/telemt.toml`. 5. Copy and send the user secret from the [access.users] section to the bot. 6. Copy the tag received from the bot. For example 1234567890abcdef1234567890abcdef. > [!WARNING] @@ -33,6 +33,9 @@ hello = "ad_tag" hello2 = "ad_tag2" ``` +## Why is middle proxy (ME) needed +https://github.com/telemt/telemt/discussions/167 + ## How many people can use 1 link By default, 1 link can be used by any number of people. @@ -60,9 +63,12 @@ user3 = "00000000000000000000000000000003" curl -s http://127.0.0.1:9091/v1/users | jq ``` +## "Unknown TLS SNI" Error +You probably updated tls_domain, but users are still connecting via old links with the previous domain. + ## How to view metrics -1. Open the config `nano /etc/telemt.toml` +1. Open the config `nano /etc/telemt/telemt.toml` 2. Add the following parameters ```toml [server] diff --git a/docs/FAQ.ru.md b/docs/FAQ.ru.md index ae38cab..15c5d8c 100644 --- a/docs/FAQ.ru.md +++ b/docs/FAQ.ru.md @@ -3,7 +3,7 @@ 1. Зайти в бота @MTProxybot. 2. Ввести команду `/newproxy` 3. Отправить IP и порт сервера. Например: 1.2.3.4:443 -4. Открыть конфиг `nano /etc/telemt.toml`. +4. Открыть конфиг `nano /etc/telemt/telemt.toml`. 5. Скопировать и отправить боту секрет пользователя из раздела [access.users]. 6. Скопировать полученный tag у бота. Например 1234567890abcdef1234567890abcdef. > [!WARNING] @@ -33,6 +33,10 @@ hello = "ad_tag" hello2 = "ad_tag2" ``` +## Зачем нужен middle proxy (ME) +https://github.com/telemt/telemt/discussions/167 + + ## Сколько человек может пользоваться 1 ссылкой По умолчанию 1 ссылкой может пользоваться сколько угодно человек. 
@@ -60,9 +64,12 @@ user3 = "00000000000000000000000000000003" curl -s http://127.0.0.1:9091/v1/users | jq ``` +## Ошибка "Unknown TLS SNI" +Возможно, вы обновили tls_domain, но пользователи всё ещё пытаются подключаться по старым ссылкам с прежним доменом. + ## Как посмотреть метрики -1. Открыть конфиг `nano /etc/telemt.toml` +1. Открыть конфиг `nano /etc/telemt/telemt.toml` 2. Добавить следующие параметры ```toml [server] diff --git a/docs/QUICK_START_GUIDE.en.md b/docs/QUICK_START_GUIDE.en.md index ffb387f..f6df4c4 100644 --- a/docs/QUICK_START_GUIDE.en.md +++ b/docs/QUICK_START_GUIDE.en.md @@ -27,12 +27,12 @@ chmod +x /bin/telemt **0. Check port and generate secrets** -The port you have selected for use should be MISSING from the list, when: +The port you have selected for use should not be in the list: ```bash netstat -lnp ``` -Generate 16 bytes/32 characters HEX with OpenSSL or another way: +Generate 16 bytes/32 characters in HEX format with OpenSSL or another way: ```bash openssl rand -hex 16 ``` @@ -50,7 +50,7 @@ Save the obtained result somewhere. You will need it later! **1. Place your config to /etc/telemt/telemt.toml** -Create config directory: +Create the config directory: ```bash mkdir /etc/telemt ``` @@ -59,7 +59,7 @@ Open nano ```bash nano /etc/telemt/telemt.toml ``` -paste your config +Insert your configuration: ```toml # === General Settings === @@ -93,8 +93,9 @@ hello = "00000000000000000000000000000000" then Ctrl+S -> Ctrl+X to save > [!WARNING] -> Replace the value of the hello parameter with the value you obtained in step 0. -> Replace the value of the tls_domain parameter with another website. +> Replace the value of the hello parameter with the value you obtained in step 0. +> Additionally, change the value of the tls_domain parameter to a different website. +> Changing the tls_domain parameter will break all links that use the old domain! --- @@ -105,14 +106,14 @@ useradd -d /opt/telemt -m -r -U telemt chown -R telemt:telemt /etc/telemt ``` -**3. Create service on /etc/systemd/system/telemt.service** +**3. Create service in /etc/systemd/system/telemt.service** Open nano ```bash nano /etc/systemd/system/telemt.service ``` -paste this Systemd Module +Insert this Systemd module: ```bash [Unit] Description=Telemt @@ -147,13 +148,16 @@ systemctl daemon-reload **6.** For automatic startup at system boot, enter `systemctl enable telemt` -**7.** To get the link(s), enter +**7.** To get the link(s), enter: ```bash curl -s http://127.0.0.1:9091/v1/users | jq ``` > Any number of people can use one link. +> [!WARNING] +> Only the command from step 7 can provide a working link. Do not try to create it yourself or copy it from anywhere if you are not sure what you are doing! + --- # Telemt via Docker Compose diff --git a/docs/QUICK_START_GUIDE.ru.md b/docs/QUICK_START_GUIDE.ru.md index c90e0de..3925953 100644 --- a/docs/QUICK_START_GUIDE.ru.md +++ b/docs/QUICK_START_GUIDE.ru.md @@ -95,6 +95,7 @@ hello = "00000000000000000000000000000000" > [!WARNING] > Замените значение параметра hello на значение, которое вы получили в пункте 0. > Так же замените значение параметра tls_domain на другой сайт. +> Изменение параметра tls_domain сделает нерабочими все ссылки, использующие старый домен! 
 ---
diff --git a/docs/VPS_DOUBLE_HOP.en.md b/docs/VPS_DOUBLE_HOP.en.md
new file mode 100644
index 0000000..6b6abe5
--- /dev/null
+++ b/docs/VPS_DOUBLE_HOP.en.md
@@ -0,0 +1,287 @@
+
+
+## Concept
+- **Server A** (_conditionally Russian Federation_):\
+  Entry point, receives Telegram proxy user traffic via **HAProxy** (port `443`)\
+  and sends it into the tunnel to Server **B**.\
+  Internal IP in the tunnel — `10.10.10.2`\
+  Port for HAProxy clients — `443/tcp`
+- **Server B** (_conditionally Netherlands_):\
+  Exit point, runs **telemt** and accepts client connections through Server **A**.\
+  The server must have unrestricted access to Telegram servers.\
+  Internal IP in the tunnel — `10.10.10.1`\
+  AmneziaWG port — `8443/udp`\
+  Port for telemt clients — `443/tcp`
+
+---
+
+## Step 1. Setting up the AmneziaWG tunnel (A <-> B)
+[AmneziaWG](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module) must be installed on all servers.\
+All following commands are given for **Ubuntu 24.04**.\
+For RHEL-based distributions, installation instructions are available at the link above.
+
+### Installing AmneziaWG (Servers A and B)
+The following steps must be performed on each server:
+
+#### 1. Adding the AmneziaWG repository and installing required packages:
+```bash
+sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
+sudo add-apt-repository ppa:amnezia/ppa && \
+sudo apt-get install -y amneziawg
+```
+
+#### 2. Generating a unique key pair:
+```bash
+cd /etc/amnezia/amneziawg && \
+awg genkey | tee private.key | awg pubkey > public.key
+```
+
+As a result, you will get two files in the `/etc/amnezia/amneziawg` folder:\
+`private.key` - the server's private key, and\
+`public.key` - the server's public key
+
+#### 3. Configuring network interfaces:
+Obfuscation parameters `S1`, `S2`, `H1`, `H2`, `H3`, `H4` must be strictly identical on both servers.\
+Parameters `Jc`, `Jmin` and `Jmax` can differ.\
+Parameters `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) must be specified on the client side (Server **A**).
+
+Recommendations for choosing values:
+
+```text
+Jc — 1 ≤ Jc ≤ 128; from 4 to 12 inclusive
+Jmin — Jmax > Jmin < 1280*; recommended 8
+Jmax — Jmin < Jmax ≤ 1280*; recommended 80
+S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
+recommended range from 15 to 150 inclusive
+S2 — S2 ≤ 1188* (1280* - 92 = 1188);
+recommended range from 15 to 150 inclusive
+H1/H2/H3/H4 — must be unique and differ from each other;
+recommended range from 5 to 2147483647 inclusive
+
+* It is assumed that the Internet connection has an MTU of 1280.
+```
+
+> [!IMPORTANT]
+> It is recommended to use your own, unique values.\
+> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters.
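The recommendations above are mechanical enough to script. A minimal sketch that draws candidate values inside the recommended ranges with coreutils `shuf`; this helper is hypothetical and not part of AmneziaWG or telemt:

```bash
#!/usr/bin/env bash
# Sketch: generate AmneziaWG obfuscation values within the ranges above
set -euo pipefail

rand_in() { shuf -i "$1-$2" -n 1; }

JC=$(rand_in 4 12)
JMIN=8
JMAX=80
S1=$(rand_in 15 150)
S2=$(rand_in 15 150)
# Re-roll S2 until the "S1 + 56 != S2" constraint holds
while [ $((S1 + 56)) -eq "$S2" ]; do S2=$(rand_in 15 150); done

# H1..H4 must be unique; shuf without -r emits distinct values
H=($(shuf -i 5-2147483647 -n 4))

printf 'Jc = %s\nJmin = %s\nJmax = %s\nS1 = %s\nS2 = %s\n' "$JC" "$JMIN" "$JMAX" "$S1" "$S2"
printf 'H1 = %s\nH2 = %s\nH3 = %s\nH4 = %s\n' "${H[@]}"
```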
+
+#### Server B Configuration (Netherlands):
+
+Create the interface configuration file (`awg0`)
+```bash
+nano /etc/amnezia/amneziawg/awg0.conf
+```
+
+File content
+```ini
+[Interface]
+Address = 10.10.10.1/24
+ListenPort = 8443
+PrivateKey = <Server B private key>
+SaveConfig = true
+Jc = 4
+Jmin = 8
+Jmax = 80
+S1 = 29
+S2 = 15
+S3 = 18
+S4 = 0
+H1 = 2087563914
+H2 = 188817757
+H3 = 101784570
+H4 = 432174303
+
+[Peer]
+PublicKey = <Server A public key>
+AllowedIPs = 10.10.10.2/32
+```
+`ListenPort` - the port on which the server will wait for connections; you can choose any free one.\
+`<Server B private key>` - the content of the `private.key` file from Server **B**.\
+`<Server A public key>` - the content of the `public.key` file from Server **A**.
+
+Open the port on the firewall (if enabled):
+```bash
+sudo ufw allow from <Server A external IP> to any port 8443 proto udp
+```
+
+`<Server A external IP>` - the external IP address of Server **A**.
+
+#### Server A Configuration (Russian Federation):
+Create the interface configuration file (`awg0`)
+
+```bash
+nano /etc/amnezia/amneziawg/awg0.conf
+```
+
+File content
+```ini
+[Interface]
+Address = 10.10.10.2/24
+PrivateKey = <Server A private key>
+Jc = 4
+Jmin = 8
+Jmax = 80
+S1 = 29
+S2 = 15
+S3 = 18
+S4 = 0
+H1 = 2087563914
+H2 = 188817757
+H3 = 101784570
+H4 = 432174303
+I1 = <signature 1>
+I2 = <signature 2>
+I3 = <signature 3>
+I4 = <signature 4>
+I5 = <signature 5>
+
+[Peer]
+PublicKey = <Server B public key>
+Endpoint = <Server B public IP>:8443
+AllowedIPs = 10.10.10.1/32
+PersistentKeepalive = 25
+```
+
+`<Server A private key>` - the content of the `private.key` file from Server **A**.\
+`<Server B public key>` - the content of the `public.key` file from Server **B**.\
+`<Server B public IP>` - the public IP address of Server **B**.
+
+Enable the tunnel on both servers:
+```bash
+sudo systemctl enable --now awg-quick@awg0
+```
+
+Make sure Server B is accessible from Server A through the tunnel.
+```bash
+ping 10.10.10.1
+PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
+64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
+64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
+64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
+^C
+```
+---
+
+## Step 2. Installing telemt on Server B (conditionally Netherlands)
+Installation and configuration are described [here](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) or [here](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
+It is assumed that telemt expects connections on port `443/tcp`.
+
+In the telemt config, you must enable the `PROXY` protocol and restrict connections to it to the tunnel only.
+```toml
+[server]
+port = 443
+listen_addr_ipv4 = "10.10.10.1"
+proxy_protocol = true
+```
+
+Also, for correct link generation, specify the FQDN or IP address and the port of Server `A`
+```toml
+[general.links]
+show = "*"
+public_host = "<Server A FQDN or IP>"
+public_port = 443
+```
+
+Open the port on the firewall (if enabled):
+```bash
+sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
+```
+
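Before wiring up HAProxy, it is worth confirming that the telemt listener is reachable only via the tunnel. A quick probe from Server A, assuming OpenBSD `nc` is installed; `<Server B public IP>` is a placeholder:

```bash
# From Server A: the telemt port must answer on the tunnel address...
nc -vz -w 3 10.10.10.1 443
# ...and must NOT answer on Server B's public address
nc -vz -w 3 <Server B public IP> 443 && echo "WARNING: telemt is exposed publicly"
```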
+---
+
+## Step 3. Configuring HAProxy on Server A (Russian Federation)
+Since the version in the standard Ubuntu repository is relatively old, it makes sense to use the official Docker image.\
+[Instructions](https://docs.docker.com/engine/install/ubuntu/) for installing Docker on Ubuntu.
+
+> [!WARNING]
+> By default, regular users do not have rights to use ports < 1024.
+> Attempts to run HAProxy on port 443 can lead to errors:
+> ```
+> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
+> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
+> ```
+> There are two simple ways to bypass this restriction, choose one:
+> 1. At the OS level, change the net.ipv4.ip_unprivileged_port_start setting to allow users to use all ports:
+> ```
+> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
+> ```
+> or
+>
+> 2. Run HAProxy as root:\
+> Uncomment the `user: "root"` parameter in docker-compose.yaml.
+
+#### Create a folder for HAProxy:
+```bash
+mkdir -p /opt/docker-compose/haproxy && cd $_
+```
+
+#### Create the docker-compose.yaml file
+`nano docker-compose.yaml`
+
+File content
+```yaml
+services:
+  haproxy:
+    image: haproxy:latest
+    container_name: haproxy
+    restart: unless-stopped
+    # user: "root"
+    network_mode: "host"
+    volumes:
+      - ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "1m"
+        max-file: "1"
+```
+
+#### Create the haproxy.cfg config file
+Accept connections on port 443/tcp and send them through the tunnel to Server `B` (10.10.10.1:443)
+
+`nano haproxy.cfg`
+
+File content
+
+```haproxy
+global
+    log stdout format raw local0
+    maxconn 10000
+
+defaults
+    log global
+    mode tcp
+    option tcplog
+    option clitcpka
+    option srvtcpka
+    timeout connect 5s
+    timeout client 2h
+    timeout server 2h
+    timeout check 5s
+
+frontend tcp_in_443
+    bind *:443
+    maxconn 8000
+    option tcp-smart-accept
+    default_backend telemt_nodes
+
+backend telemt_nodes
+    option tcp-smart-connect
+    server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
+
+
+```
+> [!WARNING]
+> **The file must end with an empty line, otherwise HAProxy will not start!**
+
+#### Allow port 443/tcp in the firewall (if enabled)
+```bash
+sudo ufw allow 443/tcp
+```
+
+#### Start the HAProxy container
+```bash
+docker compose up -d
+```
+
+If everything is configured correctly, you can now try connecting Telegram clients using links from the telemt log/api.
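A cheap sanity check before (re)starting the container is to let HAProxy parse the config in check mode, as documented for the official image:

```bash
# Validate haproxy.cfg with the same image the compose file uses (-c = check only)
docker run --rm -v "$PWD/haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro" \
  haproxy:latest haproxy -c -f /usr/local/etc/haproxy/haproxy.cfg
```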
+
+#### 2. Генерация уникальной пары ключей:
+```bash
+cd /etc/amnezia/amneziawg && \
+awg genkey | tee private.key | awg pubkey > public.key
+```
+В результате вы получите в папке `/etc/amnezia/amneziawg` два файла:\
+`private.key` - приватный ключ сервера и\
+`public.key` - публичный ключ сервера.
+
+#### 3. Настройка сетевых интерфейсов:
+
+Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
+Параметры `Jc`, `Jmin` и `Jmax` могут отличаться.\
+Параметры `I1`-`I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**).
+
+Рекомендации по выбору значений:
+```text
+Jc — 1 ≤ Jc ≤ 128; рекомендовано от 4 до 12 включительно
+Jmin — Jmin < Jmax, Jmin < 1280*; рекомендовано 8
+Jmax — Jmin < Jmax ≤ 1280*; рекомендовано 80
+S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
+рекомендованный диапазон от 15 до 150 включительно
+S2 — S2 ≤ 1188* (1280* - 92 = 1188);
+рекомендованный диапазон от 15 до 150 включительно
+H1/H2/H3/H4 — должны быть уникальны и отличаться друг от друга;
+рекомендованный диапазон от 5 до 2147483647 включительно
+
+* Предполагается, что подключение к Интернету имеет MTU 1280.
+```
+> [!IMPORTANT]
+> Рекомендуется использовать собственные, уникальные значения.\
+> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html).
+
+#### Конфигурация Сервера B (_Нидерланды_):
+
+Создаем файл конфигурации интерфейса (`awg0`)
+```bash
+nano /etc/amnezia/amneziawg/awg0.conf
+```
+
+Содержимое файла
+```ini
+[Interface]
+Address = 10.10.10.1/24
+ListenPort = 8443
+PrivateKey = <приватный ключ Сервера B>
+SaveConfig = true
+Jc = 4
+Jmin = 8
+Jmax = 80
+S1 = 29
+S2 = 15
+S3 = 18
+S4 = 0
+H1 = 2087563914
+H2 = 188817757
+H3 = 101784570
+H4 = 432174303
+
+[Peer]
+PublicKey = <публичный ключ Сервера A>
+AllowedIPs = 10.10.10.2/32
+```
+
+`ListenPort` - порт, на котором сервер будет ждать подключения, можете выбрать любой свободный.\
+`<приватный ключ Сервера B>` - содержимое файла `private.key` с сервера **B**.\
+`<публичный ключ Сервера A>` - содержимое файла `public.key` с сервера **A**.
+
+Открываем порт на фаерволе (если включен):
+```bash
+sudo ufw allow from <внешний IP Сервера A> to any port 8443 proto udp
+```
+
+`<внешний IP Сервера A>` - внешний IP адрес Сервера **A**.
+
+#### Конфигурация Сервера A (_РФ_):
+
+Создаем файл конфигурации интерфейса (`awg0`)
+```bash
+nano /etc/amnezia/amneziawg/awg0.conf
+```
+
+Содержимое файла
+```ini
+[Interface]
+Address = 10.10.10.2/24
+PrivateKey = <приватный ключ Сервера A>
+Jc = 4
+Jmin = 8
+Jmax = 80
+S1 = 29
+S2 = 15
+S3 = 18
+S4 = 0
+H1 = 2087563914
+H2 = 188817757
+H3 = 101784570
+H4 = 432174303
+I1 = 
+I2 = 
+I3 = 
+I4 = 
+I5 = 
+
+[Peer]
+PublicKey = <публичный ключ Сервера B>
+Endpoint = <публичный IP Сервера B>:8443
+AllowedIPs = 10.10.10.1/32
+PersistentKeepalive = 25
+```
+
+`<приватный ключ Сервера A>` - содержимое файла `private.key` с сервера **A**.\
+`<публичный ключ Сервера B>` - содержимое файла `public.key` с сервера **B**.\
+`<публичный IP Сервера B>` - публичный IP адрес сервера **B**.
+
+#### Включаем туннель на обоих серверах:
+```bash
+sudo systemctl enable --now awg-quick@awg0
+```
+
+Убедитесь, что с Сервера `A` доступен Сервер `B` через туннель.
+```bash
+ping 10.10.10.1
+PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
+64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
+64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
+64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
+^C
+```
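+
+Дополнительно состояние туннеля (handshake, счетчики трафика) можно посмотреть утилитой `awg`, которая работает аналогично `wg show`:
+```bash
+sudo awg show awg0
+```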
+
+---
+
+## Шаг 2. Установка telemt на Сервере B (_условно Нидерланды_)
+
+Установка и настройка описаны [здесь](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) или [здесь](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
+Подразумевается, что telemt ожидает подключения на порту `443/tcp`.
+
+В конфиге telemt необходимо включить протокол `PROXY` и ограничить подключения к нему только через туннель.
+
+```toml
+[server]
+port = 443
+listen_addr_ipv4 = "10.10.10.1"
+proxy_protocol = true
+```
+
+А также, для правильной генерации ссылок, указать FQDN или IP адрес и порт Сервера `A`:
+
+```toml
+[general.links]
+show = "*"
+public_host = "<FQDN или IP Сервера A>"
+public_port = 443
+```
+
+Открываем порт на фаерволе (если включен):
+```bash
+sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
+```
+
+---
+
+## Шаг 3. Настройка HAProxy на Сервере A (_РФ_)
+
+Так как в стандартном репозитории Ubuntu версия относительно старая, имеет смысл воспользоваться официальным образом Docker.\
+[Инструкция](https://docs.docker.com/engine/install/ubuntu/) по установке Docker на Ubuntu.
+
+> [!WARNING]
+> По умолчанию у обычных пользователей нет прав на использование портов < 1024.\
+> Попытки запустить HAProxy на 443 порту могут приводить к ошибкам:
+> ```
+> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
+> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
+> ```
+> Есть два простых способа обойти это ограничение, выберите что-то одно:
+> 1. На уровне ОС изменить настройку `net.ipv4.ip_unprivileged_port_start`, разрешив пользователям использовать все порты:
+> ```
+> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
+> ```
+> или
+>
+> 2. Запустить HAProxy под root:\
+> Раскомментируйте в docker-compose.yaml параметр `user: "root"`.
+
+#### Создаем папку для HAProxy:
+```bash
+mkdir -p /opt/docker-compose/haproxy && cd $_
+```
+
+#### Создаем файл docker-compose.yaml
+
+`nano docker-compose.yaml`
+
+Содержимое файла
+```yaml
+services:
+  haproxy:
+    image: haproxy:latest
+    container_name: haproxy
+    restart: unless-stopped
+    # user: "root"
+    network_mode: "host"
+    volumes:
+      - ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "1m"
+        max-file: "1"
+```
+
+#### Создаем файл конфига haproxy.cfg
+Принимаем подключения на порту 443/tcp и отправляем их через туннель на Сервер `B` (10.10.10.1:443)
+
+`nano haproxy.cfg`
+
+Содержимое файла
+```haproxy
+global
+    log stdout format raw local0
+    maxconn 10000
+
+defaults
+    log global
+    mode tcp
+    option tcplog
+    option clitcpka
+    option srvtcpka
+    timeout connect 5s
+    timeout client 2h
+    timeout server 2h
+    timeout check 5s
+
+frontend tcp_in_443
+    bind *:443
+    maxconn 8000
+    option tcp-smart-accept
+    default_backend telemt_nodes
+
+backend telemt_nodes
+    option tcp-smart-connect
+    server server_b 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
+
+
+```
+> [!WARNING]
+> **Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!**
+
+#### Разрешаем порт 443/tcp в фаерволе (если включен)
+```bash
+sudo ufw allow 443/tcp
+```
+
+#### Запускаем контейнер HAProxy
+```bash
+docker compose up -d
+```
+
+Если все настроено верно, то теперь можно пробовать подключать клиентов Telegram, используя ссылки из лога/API telemt.
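+
+Ссылки также можно получить напрямую из API telemt, например (адрес и порт API здесь условные, подставьте значения из вашей секции `[server.api]`):
+```bash
+API="10.10.10.1:8888"
+curl -s "http://${API}/v1/users"
+```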
diff --git a/src/api/mod.rs b/src/api/mod.rs index c1e3557..c0eab87 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -42,6 +42,7 @@ use events::ApiEventStore; use http_utils::{error_response, read_json, read_optional_json, success_response}; use model::{ ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, + UserActiveIps, }; use runtime_edge::{ EdgeConnectionsCacheEntry, build_runtime_connections_summary_data, @@ -362,6 +363,18 @@ async fn handle( ); Ok(success_response(StatusCode::OK, data, revision)) } + ("GET", "/v1/stats/users/active-ips") => { + let revision = current_revision(&shared.config_path).await?; + let usernames: Vec<_> = cfg.access.users.keys().cloned().collect(); + let active_ips_map = shared.ip_tracker.get_active_ips_for_users(&usernames).await; + let mut data: Vec = active_ips_map + .into_iter() + .filter(|(_, ips)| !ips.is_empty()) + .map(|(username, active_ips)| UserActiveIps { username, active_ips }) + .collect(); + data.sort_by(|a, b| a.username.cmp(&b.username)); + Ok(success_response(StatusCode::OK, data, revision)) + } ("GET", "/v1/stats/users") | ("GET", "/v1/users") => { let revision = current_revision(&shared.config_path).await?; let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); diff --git a/src/api/model.rs b/src/api/model.rs index 94b50f6..164042f 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -174,6 +174,24 @@ pub(super) struct ZeroMiddleProxyData { pub(super) route_drop_queue_full_total: u64, pub(super) route_drop_queue_full_base_total: u64, pub(super) route_drop_queue_full_high_total: u64, + pub(super) d2c_batches_total: u64, + pub(super) d2c_batch_frames_total: u64, + pub(super) d2c_batch_bytes_total: u64, + pub(super) d2c_flush_reason_queue_drain_total: u64, + pub(super) d2c_flush_reason_batch_frames_total: u64, + pub(super) d2c_flush_reason_batch_bytes_total: u64, + pub(super) d2c_flush_reason_max_delay_total: u64, + pub(super) d2c_flush_reason_ack_immediate_total: u64, + pub(super) d2c_flush_reason_close_total: u64, + pub(super) d2c_data_frames_total: u64, + pub(super) d2c_ack_frames_total: u64, + pub(super) d2c_payload_bytes_total: u64, + pub(super) d2c_write_mode_coalesced_total: u64, + pub(super) d2c_write_mode_split_total: u64, + pub(super) d2c_quota_reject_pre_write_total: u64, + pub(super) d2c_quota_reject_post_write_total: u64, + pub(super) d2c_frame_buf_shrink_total: u64, + pub(super) d2c_frame_buf_shrink_bytes_total: u64, pub(super) socks_kdf_strict_reject_total: u64, pub(super) socks_kdf_compat_fallback_total: u64, pub(super) endpoint_quarantine_total: u64, @@ -424,6 +442,12 @@ pub(super) struct UserInfo { pub(super) links: UserLinks, } +#[derive(Serialize)] +pub(super) struct UserActiveIps { + pub(super) username: String, + pub(super) active_ips: Vec, +} + #[derive(Serialize)] pub(super) struct CreateUserResponse { pub(super) user: UserInfo, diff --git a/src/api/runtime_stats.rs b/src/api/runtime_stats.rs index 94f27a9..b66d1a5 100644 --- a/src/api/runtime_stats.rs +++ b/src/api/runtime_stats.rs @@ -68,6 +68,25 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer route_drop_queue_full_total: stats.get_me_route_drop_queue_full(), route_drop_queue_full_base_total: stats.get_me_route_drop_queue_full_base(), route_drop_queue_full_high_total: stats.get_me_route_drop_queue_full_high(), + d2c_batches_total: stats.get_me_d2c_batches_total(), + d2c_batch_frames_total: stats.get_me_d2c_batch_frames_total(), + d2c_batch_bytes_total: 
stats.get_me_d2c_batch_bytes_total(), + d2c_flush_reason_queue_drain_total: stats.get_me_d2c_flush_reason_queue_drain_total(), + d2c_flush_reason_batch_frames_total: stats.get_me_d2c_flush_reason_batch_frames_total(), + d2c_flush_reason_batch_bytes_total: stats.get_me_d2c_flush_reason_batch_bytes_total(), + d2c_flush_reason_max_delay_total: stats.get_me_d2c_flush_reason_max_delay_total(), + d2c_flush_reason_ack_immediate_total: stats + .get_me_d2c_flush_reason_ack_immediate_total(), + d2c_flush_reason_close_total: stats.get_me_d2c_flush_reason_close_total(), + d2c_data_frames_total: stats.get_me_d2c_data_frames_total(), + d2c_ack_frames_total: stats.get_me_d2c_ack_frames_total(), + d2c_payload_bytes_total: stats.get_me_d2c_payload_bytes_total(), + d2c_write_mode_coalesced_total: stats.get_me_d2c_write_mode_coalesced_total(), + d2c_write_mode_split_total: stats.get_me_d2c_write_mode_split_total(), + d2c_quota_reject_pre_write_total: stats.get_me_d2c_quota_reject_pre_write_total(), + d2c_quota_reject_post_write_total: stats.get_me_d2c_quota_reject_post_write_total(), + d2c_frame_buf_shrink_total: stats.get_me_d2c_frame_buf_shrink_total(), + d2c_frame_buf_shrink_bytes_total: stats.get_me_d2c_frame_buf_shrink_bytes_total(), socks_kdf_strict_reject_total: stats.get_me_socks_kdf_strict_reject(), socks_kdf_compat_fallback_total: stats.get_me_socks_kdf_compat_fallback(), endpoint_quarantine_total: stats.get_me_endpoint_quarantine_total(), diff --git a/src/api/runtime_zero.rs b/src/api/runtime_zero.rs index a6eb163..0ed84a8 100644 --- a/src/api/runtime_zero.rs +++ b/src/api/runtime_zero.rs @@ -35,11 +35,14 @@ pub(super) struct RuntimeGatesData { pub(super) conditional_cast_enabled: bool, pub(super) me_runtime_ready: bool, pub(super) me2dc_fallback_enabled: bool, + pub(super) me2dc_fast_enabled: bool, pub(super) use_middle_proxy: bool, pub(super) route_mode: &'static str, pub(super) reroute_active: bool, #[serde(skip_serializing_if = "Option::is_none")] pub(super) reroute_to_direct_at_epoch_secs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(super) reroute_reason: Option<&'static str>, pub(super) startup_status: &'static str, pub(super) startup_stage: String, pub(super) startup_progress_pct: f64, @@ -86,6 +89,7 @@ pub(super) struct EffectiveMiddleProxyLimits { pub(super) writer_pick_mode: &'static str, pub(super) writer_pick_sample_size: u8, pub(super) me2dc_fallback: bool, + pub(super) me2dc_fast: bool, } #[derive(Serialize)] @@ -169,6 +173,8 @@ pub(super) async fn build_runtime_gates_data( let startup_summary = build_runtime_startup_summary(shared).await; let route_state = shared.route_runtime.snapshot(); let route_mode = route_state.mode.as_str(); + let fast_fallback_enabled = + cfg.general.use_middle_proxy && cfg.general.me2dc_fallback && cfg.general.me2dc_fast; let reroute_active = cfg.general.use_middle_proxy && cfg.general.me2dc_fallback && matches!(route_state.mode, RelayRouteMode::Direct); @@ -177,6 +183,15 @@ pub(super) async fn build_runtime_gates_data( } else { None }; + let reroute_reason = if reroute_active { + if fast_fallback_enabled { + Some("fast_not_ready_fallback") + } else { + Some("strict_grace_fallback") + } + } else { + None + }; let me_runtime_ready = if !cfg.general.use_middle_proxy { true } else { @@ -194,10 +209,12 @@ pub(super) async fn build_runtime_gates_data( conditional_cast_enabled: cfg.general.use_middle_proxy, me_runtime_ready, me2dc_fallback_enabled: cfg.general.me2dc_fallback, + me2dc_fast_enabled: fast_fallback_enabled, use_middle_proxy: 
cfg.general.use_middle_proxy, route_mode, reroute_active, reroute_to_direct_at_epoch_secs, + reroute_reason, startup_status: startup_summary.status, startup_stage: startup_summary.stage, startup_progress_pct: startup_summary.progress_pct, @@ -263,6 +280,7 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD writer_pick_mode: me_writer_pick_mode_label(cfg.general.me_writer_pick_mode), writer_pick_sample_size: cfg.general.me_writer_pick_sample_size, me2dc_fallback: cfg.general.me2dc_fallback, + me2dc_fast: cfg.general.me2dc_fast, }, user_ip_policy: EffectiveUserIpPolicyLimits { global_each: cfg.access.user_max_unique_ips_global_each, diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 650d70d..608e1b8 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -29,6 +29,8 @@ const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_FRAMES: usize = 32; const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_BYTES: usize = 128 * 1024; const DEFAULT_ME_D2C_FLUSH_BATCH_MAX_DELAY_US: u64 = 500; const DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE: bool = true; +const DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES: u64 = 64 * 1024; +const DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES: usize = 256 * 1024; const DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES: usize = 64 * 1024; const DEFAULT_DIRECT_RELAY_COPY_BUF_S2C_BYTES: usize = 256 * 1024; const DEFAULT_ME_WRITER_PICK_SAMPLE_SIZE: u8 = 3; @@ -69,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String { String::new() } +pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 { + 5_000 +} + +pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 { + 15_000 +} + +pub(crate) fn default_tls_fetch_strict_route() -> bool { + true +} + +pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 { + 600 +} + pub(crate) fn default_mask_port() -> u16 { 443 } @@ -183,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 } +pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec { + vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] +} + pub(crate) fn default_server_max_connections() -> u32 { 10_000 } @@ -251,6 +273,10 @@ pub(crate) fn default_me2dc_fallback() -> bool { true } +pub(crate) fn default_me2dc_fast() -> bool { + false +} + pub(crate) fn default_keepalive_interval() -> u64 { 8 } @@ -387,6 +413,14 @@ pub(crate) fn default_me_d2c_ack_flush_immediate() -> bool { DEFAULT_ME_D2C_ACK_FLUSH_IMMEDIATE } +pub(crate) fn default_me_quota_soft_overshoot_bytes() -> u64 { + DEFAULT_ME_QUOTA_SOFT_OVERSHOOT_BYTES +} + +pub(crate) fn default_me_d2c_frame_buf_shrink_threshold_bytes() -> usize { + DEFAULT_ME_D2C_FRAME_BUF_SHRINK_THRESHOLD_BYTES +} + pub(crate) fn default_direct_relay_copy_buf_c2s_bytes() -> usize { DEFAULT_DIRECT_RELAY_COPY_BUF_C2S_BYTES } @@ -543,6 +577,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { 512 } +#[cfg(not(test))] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 5 * 1024 * 1024 +} + +#[cfg(test)] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 32 * 1024 +} + +pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 { + 5 +} + pub(crate) fn default_mask_timing_normalization_enabled() -> bool { false } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 39c31a1..9bd2927 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -106,6 +106,8 @@ pub struct HotFields { pub me_d2c_flush_batch_max_bytes: usize, pub me_d2c_flush_batch_max_delay_us: u64, pub me_d2c_ack_flush_immediate: bool, + pub 
me_quota_soft_overshoot_bytes: u64, + pub me_d2c_frame_buf_shrink_threshold_bytes: usize, pub direct_relay_copy_buf_c2s_bytes: usize, pub direct_relay_copy_buf_s2c_bytes: usize, pub me_health_interval_ms_unhealthy: u64, @@ -225,6 +227,10 @@ impl HotFields { me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate, + me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes, + me_d2c_frame_buf_shrink_threshold_bytes: cfg + .general + .me_d2c_frame_buf_shrink_threshold_bytes, direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, @@ -511,6 +517,9 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes; cfg.general.me_d2c_flush_batch_max_delay_us = new.general.me_d2c_flush_batch_max_delay_us; cfg.general.me_d2c_ack_flush_immediate = new.general.me_d2c_ack_flush_immediate; + cfg.general.me_quota_soft_overshoot_bytes = new.general.me_quota_soft_overshoot_bytes; + cfg.general.me_d2c_frame_buf_shrink_threshold_bytes = + new.general.me_d2c_frame_buf_shrink_threshold_bytes; cfg.general.direct_relay_copy_buf_c2s_bytes = new.general.direct_relay_copy_buf_c2s_bytes; cfg.general.direct_relay_copy_buf_s2c_bytes = new.general.direct_relay_copy_buf_s2c_bytes; cfg.general.me_health_interval_ms_unhealthy = new.general.me_health_interval_ms_unhealthy; @@ -593,6 +602,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes + || old.censorship.mask_classifier_prefetch_timeout_ms + != new.censorship.mask_classifier_prefetch_timeout_ms || old.censorship.mask_timing_normalization_enabled != new.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_floor_ms @@ -639,6 +651,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b } if old.general.me_route_no_writer_mode != new.general.me_route_no_writer_mode || old.general.me_route_no_writer_wait_ms != new.general.me_route_no_writer_wait_ms + || old.general.me_route_hybrid_max_wait_ms != new.general.me_route_hybrid_max_wait_ms + || old.general.me_route_blocking_send_timeout_ms + != new.general.me_route_blocking_send_timeout_ms || old.general.me_route_inline_recovery_attempts != new.general.me_route_inline_recovery_attempts || old.general.me_route_inline_recovery_wait_ms @@ -657,9 +672,11 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b warned = true; warn!("config reload: general.me_init_retry_attempts changed; restart required"); } - if old.general.me2dc_fallback != new.general.me2dc_fallback { + if old.general.me2dc_fallback != new.general.me2dc_fallback + || old.general.me2dc_fast != new.general.me2dc_fast + { warned = true; - warn!("config reload: general.me2dc_fallback changed; restart required"); + warn!("config reload: general.me2dc_fallback/me2dc_fast changed; restart required"); } if 
old.general.proxy_config_v4_cache_path != new.general.proxy_config_v4_cache_path || old.general.proxy_config_v6_cache_path != new.general.proxy_config_v6_cache_path @@ -1030,15 +1047,20 @@ fn log_changes( || old_hot.me_d2c_flush_batch_max_bytes != new_hot.me_d2c_flush_batch_max_bytes || old_hot.me_d2c_flush_batch_max_delay_us != new_hot.me_d2c_flush_batch_max_delay_us || old_hot.me_d2c_ack_flush_immediate != new_hot.me_d2c_ack_flush_immediate + || old_hot.me_quota_soft_overshoot_bytes != new_hot.me_quota_soft_overshoot_bytes + || old_hot.me_d2c_frame_buf_shrink_threshold_bytes + != new_hot.me_d2c_frame_buf_shrink_threshold_bytes || old_hot.direct_relay_copy_buf_c2s_bytes != new_hot.direct_relay_copy_buf_c2s_bytes || old_hot.direct_relay_copy_buf_s2c_bytes != new_hot.direct_relay_copy_buf_s2c_bytes { info!( - "config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} direct_buf_c2s={} direct_buf_s2c={}", + "config reload: relay_tuning: me_d2c_frames={} me_d2c_bytes={} me_d2c_delay_us={} me_ack_flush_immediate={} me_quota_soft_overshoot_bytes={} me_d2c_frame_buf_shrink_threshold_bytes={} direct_buf_c2s={} direct_buf_s2c={}", new_hot.me_d2c_flush_batch_max_frames, new_hot.me_d2c_flush_batch_max_bytes, new_hot.me_d2c_flush_batch_max_delay_us, new_hot.me_d2c_ack_flush_immediate, + new_hot.me_quota_soft_overshoot_bytes, + new_hot.me_d2c_frame_buf_shrink_threshold_bytes, new_hot.direct_relay_copy_buf_c2s_bytes, new_hot.direct_relay_copy_buf_s2c_bytes, ); diff --git a/src/config/load.rs b/src/config/load.rs index 2382878..7892e2c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; @@ -430,6 +430,24 @@ impl ProxyConfig { )); } + if config.censorship.mask_relay_max_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_relay_max_bytes > 67_108_864 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), + )); + } + + if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { + return Err(ProxyError::Config( + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(), + )); + } + if config.censorship.mask_timing_normalization_ceiling_ms < config.censorship.mask_timing_normalization_floor_ms { @@ -533,6 +551,21 @@ impl ProxyConfig { )); } + if config.general.me_quota_soft_overshoot_bytes > 16 * 1024 * 1024 { + return Err(ProxyError::Config( + "general.me_quota_soft_overshoot_bytes must be within [0, 16777216]".to_string(), + )); + } + + if !(4096..=16 * 1024 * 1024) + .contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) + { + return Err(ProxyError::Config( + "general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]" + .to_string(), + )); + } + if !(4096..=1024 * 1024).contains(&config.general.direct_relay_copy_buf_c2s_bytes) { return Err(ProxyError::Config( "general.direct_relay_copy_buf_c2s_bytes must be within [4096, 1048576]" @@ -944,6 +977,28 @@ impl ProxyConfig { // Normalize optional TLS fetch scope: whitespace-only values disable scoped routing. 
config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string(); + if config.censorship.tls_fetch.profiles.is_empty() { + config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles; + } else { + let mut seen = HashSet::new(); + config + .censorship + .tls_fetch + .profiles + .retain(|profile| seen.insert(*profile)); + } + + if config.censorship.tls_fetch.attempt_timeout_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(), + )); + } + if config.censorship.tls_fetch.total_budget_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.total_budget_ms must be > 0".to_string(), + )); + } + // Merge primary + extra TLS domains, deduplicate (primary always first). if !config.censorship.tls_domains.is_empty() { let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); @@ -1121,6 +1176,10 @@ mod load_security_tests; #[path = "tests/load_mask_shape_security_tests.rs"] mod load_mask_shape_security_tests; +#[cfg(test)] +#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] +mod load_mask_classifier_prefetch_timeout_security_tests; + #[cfg(test)] mod tests { use super::*; @@ -1158,6 +1217,7 @@ mod tests { default_me_init_retry_attempts() ); assert_eq!(cfg.general.me2dc_fallback, default_me2dc_fallback()); + assert_eq!(cfg.general.me2dc_fast, default_me2dc_fast()); assert_eq!( cfg.general.proxy_config_v4_cache_path, default_proxy_config_v4_cache_path() @@ -1226,6 +1286,11 @@ mod tests { assert_eq!(cfg.general.update_every, default_update_every()); assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4()); assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt()); + assert_eq!( + cfg.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1292,6 +1357,7 @@ mod tests { default_me_init_retry_attempts() ); assert_eq!(general.me2dc_fallback, default_me2dc_fallback()); + assert_eq!(general.me2dc_fast, default_me2dc_fast()); assert_eq!( general.proxy_config_v4_cache_path, default_proxy_config_v4_cache_path() @@ -1358,6 +1424,14 @@ mod tests { let server = ServerConfig::default(); assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6())); + assert_eq!( + server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + assert_eq!( + AntiCensorshipConfig::default().unknown_sni_action, + UnknownSniAction::Drop + ); assert_eq!(server.api.listen, default_api_listen()); assert_eq!(server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1393,6 +1467,75 @@ mod tests { assert_eq!(access.users, default_access_users()); } + #[test] + fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() { + let cfg_missing: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert_eq!( + cfg_missing.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + + let cfg_explicit_empty: ProxyConfig = toml::from_str( + r#" + [server] + proxy_protocol_trusted_cidrs = [] + + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert!( + cfg_explicit_empty + .server + .proxy_protocol_trusted_cidrs + .is_empty() + ); + } + + #[test] + fn unknown_sni_action_parses_and_defaults_to_drop() { + 
let cfg_default: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + "#, + ) + .unwrap(); + assert_eq!( + cfg_default.censorship.unknown_sni_action, + UnknownSniAction::Drop + ); + + let cfg_mask: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + unknown_sni_action = "mask" + "#, + ) + .unwrap(); + assert_eq!( + cfg_mask.censorship.unknown_sni_action, + UnknownSniAction::Mask + ); + } + #[test] fn dc_overrides_allow_string_and_array() { let toml = r#" @@ -2340,6 +2483,94 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tls_fetch_defaults_are_applied() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + TlsFetchConfig::default().profiles + ); + assert!(cfg.censorship.tls_fetch.strict_route); + assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000); + assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000); + assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_profiles_are_deduplicated_preserving_order() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"] + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + vec![ + TlsFetchProfile::CompatTls12, + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::LegacyMinimal + ] + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_attempt_timeout_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + attempt_timeout_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_total_budget_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + total_budget_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_ad_tag_is_disabled_during_load() { let toml = r#" diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs new file mode 100644 index 0000000..0b3d543 --- /dev/null +++ 
b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -0,0 +1,76 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!( + "telemt-load-mask-prefetch-timeout-security-{nonce}.toml" + )); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 4 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout below minimum security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 51 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout above max security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 20 +"#, + ); + + let cfg = + ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted"); + assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 8986a49..bccd36f 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8 remove_temp_config(&path); } + +#[test] +fn load_rejects_zero_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be > 0"), + "error must explain non-zero relay cap invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_relay_max_bytes_above_upper_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 67108865 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), + "error must explain relay cap upper bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_valid_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 8388608 +"#, + ); + + let cfg = 
ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted"); + assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608); + + remove_temp_config(&path); +} diff --git a/src/config/types.rs b/src/config/types.rs index 1c5423e..cb14747 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -429,6 +429,11 @@ pub struct GeneralConfig { #[serde(default = "default_me2dc_fallback")] pub me2dc_fallback: bool, + /// Fast ME->Direct fallback mode for new sessions. + /// Active only when both `use_middle_proxy=true` and `me2dc_fallback=true`. + #[serde(default = "default_me2dc_fast")] + pub me2dc_fast: bool, + /// Enable ME keepalive padding frames. #[serde(default = "default_true")] pub me_keepalive_enabled: bool, @@ -468,7 +473,7 @@ pub struct GeneralConfig { pub me_c2me_send_timeout_ms: u64, /// Bounded wait in milliseconds for routing ME DATA to per-connection queue. - /// `0` keeps legacy no-wait behavior. + /// `0` keeps non-blocking routing; values >0 enable bounded wait for compatibility. #[serde(default = "default_me_reader_route_data_wait_ms")] pub me_reader_route_data_wait_ms: u64, @@ -489,6 +494,14 @@ pub struct GeneralConfig { #[serde(default = "default_me_d2c_ack_flush_immediate")] pub me_d2c_ack_flush_immediate: bool, + /// Additional bytes above strict per-user quota allowed in hot-path soft mode. + #[serde(default = "default_me_quota_soft_overshoot_bytes")] + pub me_quota_soft_overshoot_bytes: u64, + + /// Shrink threshold for reusable ME->Client frame assembly buffer. + #[serde(default = "default_me_d2c_frame_buf_shrink_threshold_bytes")] + pub me_d2c_frame_buf_shrink_threshold_bytes: usize, + /// Copy buffer size for client->DC direction in direct relay. #[serde(default = "default_direct_relay_copy_buf_c2s_bytes")] pub direct_relay_copy_buf_c2s_bytes: usize, @@ -931,6 +944,7 @@ impl Default for GeneralConfig { middle_proxy_warm_standby: default_middle_proxy_warm_standby(), me_init_retry_attempts: default_me_init_retry_attempts(), me2dc_fallback: default_me2dc_fallback(), + me2dc_fast: default_me2dc_fast(), me_keepalive_enabled: default_true(), me_keepalive_interval_secs: default_keepalive_interval(), me_keepalive_jitter_secs: default_keepalive_jitter(), @@ -945,6 +959,9 @@ impl Default for GeneralConfig { me_d2c_flush_batch_max_bytes: default_me_d2c_flush_batch_max_bytes(), me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), + me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(), + me_d2c_frame_buf_shrink_threshold_bytes: + default_me_d2c_frame_buf_shrink_threshold_bytes(), direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), @@ -1229,9 +1246,10 @@ pub struct ServerConfig { /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. /// - /// When non-empty, connections from addresses outside this allowlist are - /// rejected before `src_addr` is applied. - #[serde(default)] + /// If this field is omitted in config, it defaults to trust-all CIDRs + /// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list, + /// all PROXY protocol headers are rejected. + #[serde(default = "default_proxy_protocol_trusted_cidrs")] pub proxy_protocol_trusted_cidrs: Vec, /// Port for the Prometheus-compatible metrics endpoint. 
@@ -1276,7 +1294,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), - proxy_protocol_trusted_cidrs: Vec::new(), + proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), @@ -1347,6 +1365,90 @@ impl Default for TimeoutsConfig { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum UnknownSniAction { + #[default] + Drop, + Mask, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsFetchProfile { + ModernChromeLike, + ModernFirefoxLike, + CompatTls12, + LegacyMinimal, +} + +impl TlsFetchProfile { + pub fn as_str(self) -> &'static str { + match self { + TlsFetchProfile::ModernChromeLike => "modern_chrome_like", + TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like", + TlsFetchProfile::CompatTls12 => "compat_tls12", + TlsFetchProfile::LegacyMinimal => "legacy_minimal", + } + } +} + +fn default_tls_fetch_profiles() -> Vec { + vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsFetchConfig { + /// Ordered list of ClientHello profiles used for adaptive fallback. + #[serde(default = "default_tls_fetch_profiles")] + pub profiles: Vec, + + /// When true and upstream route is configured, TLS fetch fails closed on + /// upstream connect errors and does not fallback to direct TCP. + #[serde(default = "default_tls_fetch_strict_route")] + pub strict_route: bool, + + /// Timeout per one profile attempt in milliseconds. + #[serde(default = "default_tls_fetch_attempt_timeout_ms")] + pub attempt_timeout_ms: u64, + + /// Total wall-clock budget in milliseconds across all profile attempts. + #[serde(default = "default_tls_fetch_total_budget_ms")] + pub total_budget_ms: u64, + + /// Adds GREASE-style values into selected ClientHello extensions. + #[serde(default)] + pub grease_enabled: bool, + + /// Produces deterministic ClientHello randomness for debugging/tests. + #[serde(default)] + pub deterministic: bool, + + /// TTL for winner-profile cache entries in seconds. + /// Set to 0 to disable profile cache. + #[serde(default = "default_tls_fetch_profile_cache_ttl_secs")] + pub profile_cache_ttl_secs: u64, +} + +impl Default for TlsFetchConfig { + fn default() -> Self { + Self { + profiles: default_tls_fetch_profiles(), + strict_route: default_tls_fetch_strict_route(), + attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(), + total_budget_ms: default_tls_fetch_total_budget_ms(), + grease_enabled: false, + deterministic: false, + profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1356,11 +1458,19 @@ pub struct AntiCensorshipConfig { #[serde(default)] pub tls_domains: Vec, + /// Policy for TLS ClientHello with unknown (non-configured) SNI. + #[serde(default)] + pub unknown_sni_action: UnknownSniAction, + /// Upstream scope used for TLS front metadata fetches. /// Empty value keeps default upstream routing behavior. 
#[serde(default = "default_tls_fetch_scope")] pub tls_fetch_scope: String, + /// Fetch strategy for TLS front metadata bootstrap and periodic refresh. + #[serde(default)] + pub tls_fetch: TlsFetchConfig, + #[serde(default = "default_true")] pub mask: bool, @@ -1440,6 +1550,14 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] pub mask_shape_above_cap_blur_max_bytes: usize, + /// Maximum bytes relayed per direction on unauthenticated masking fallback paths. + #[serde(default = "default_mask_relay_max_bytes")] + pub mask_relay_max_bytes: usize, + + /// Prefetch timeout (ms) for extending fragmented masking classifier window. + #[serde(default = "default_mask_classifier_prefetch_timeout_ms")] + pub mask_classifier_prefetch_timeout_ms: u64, + /// Enable outcome-time normalization envelope for masking fallback. #[serde(default = "default_mask_timing_normalization_enabled")] pub mask_timing_normalization_enabled: bool, @@ -1458,7 +1576,9 @@ impl Default for AntiCensorshipConfig { Self { tls_domain: default_tls_domain(), tls_domains: Vec::new(), + unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), + tls_fetch: TlsFetchConfig::default(), mask: default_true(), mask_host: None, mask_port: default_mask_port(), @@ -1478,6 +1598,8 @@ impl Default for AntiCensorshipConfig { mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_relay_max_bytes: default_mask_relay_max_bytes(), + mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(), mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), diff --git a/src/error.rs b/src/error.rs index d9aeb22..49c8c81 100644 --- a/src/error.rs +++ b/src/error.rs @@ -216,6 +216,9 @@ pub enum ProxyError { #[error("Invalid proxy protocol header")] InvalidProxyProtocol, + #[error("Unknown TLS SNI")] + UnknownTlsSni, + #[error("Proxy error: {0}")] Proxy(String), diff --git a/src/maestro/admission.rs b/src/maestro/admission.rs index 69a9c9f..82484ad 100644 --- a/src/maestro/admission.rs +++ b/src/maestro/admission.rs @@ -21,10 +21,29 @@ pub(crate) async fn configure_admission_gate( if config.general.use_middle_proxy { if let Some(pool) = me_pool.as_ref() { let initial_ready = pool.admission_ready_conditional_cast().await; - admission_tx.send_replace(initial_ready); - let _ = route_runtime.set_mode(RelayRouteMode::Middle); + let mut fallback_enabled = config.general.me2dc_fallback; + let mut fast_fallback_enabled = fallback_enabled && config.general.me2dc_fast; + let (initial_gate_open, initial_route_mode, initial_fallback_reason) = if initial_ready + { + (true, RelayRouteMode::Middle, None) + } else if fast_fallback_enabled { + ( + true, + RelayRouteMode::Direct, + Some("fast_not_ready_fallback"), + ) + } else { + (false, RelayRouteMode::Middle, None) + }; + admission_tx.send_replace(initial_gate_open); + let _ = route_runtime.set_mode(initial_route_mode); if initial_ready { info!("Conditional-admission gate: open / ME pool READY"); + } else if let Some(reason) = initial_fallback_reason { + warn!( + fallback_reason = reason, + "Conditional-admission gate opened in ME fast fallback mode" + ); } else { 
warn!("Conditional-admission gate: closed / ME pool is NOT ready)"); } @@ -34,10 +53,9 @@ pub(crate) async fn configure_admission_gate( let route_runtime_gate = route_runtime.clone(); let mut config_rx_gate = config_rx.clone(); let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1); - let mut fallback_enabled = config.general.me2dc_fallback; tokio::spawn(async move { - let mut gate_open = initial_ready; - let mut route_mode = RelayRouteMode::Middle; + let mut gate_open = initial_gate_open; + let mut route_mode = initial_route_mode; let mut ready_observed = initial_ready; let mut not_ready_since = if initial_ready { None @@ -53,16 +71,23 @@ pub(crate) async fn configure_admission_gate( let cfg = config_rx_gate.borrow_and_update().clone(); admission_poll_ms = cfg.general.me_admission_poll_ms.max(1); fallback_enabled = cfg.general.me2dc_fallback; + fast_fallback_enabled = cfg.general.me2dc_fallback && cfg.general.me2dc_fast; continue; } _ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {} } let ready = pool_for_gate.admission_ready_conditional_cast().await; let now = Instant::now(); - let (next_gate_open, next_route_mode, next_fallback_active) = if ready { + let (next_gate_open, next_route_mode, next_fallback_reason) = if ready { ready_observed = true; not_ready_since = None; - (true, RelayRouteMode::Middle, false) + (true, RelayRouteMode::Middle, None) + } else if fast_fallback_enabled { + ( + true, + RelayRouteMode::Direct, + Some("fast_not_ready_fallback"), + ) } else { let not_ready_started_at = *not_ready_since.get_or_insert(now); let not_ready_for = now.saturating_duration_since(not_ready_started_at); @@ -72,11 +97,12 @@ pub(crate) async fn configure_admission_gate( STARTUP_FALLBACK_AFTER }; if fallback_enabled && not_ready_for > fallback_after { - (true, RelayRouteMode::Direct, true) + (true, RelayRouteMode::Direct, Some("strict_grace_fallback")) } else { - (false, RelayRouteMode::Middle, false) + (false, RelayRouteMode::Middle, None) } }; + let next_fallback_active = next_fallback_reason.is_some(); if next_route_mode != route_mode { route_mode = next_route_mode; @@ -88,17 +114,28 @@ pub(crate) async fn configure_admission_gate( "Middle-End routing restored for new sessions" ); } else { - let fallback_after = if ready_observed { - RUNTIME_FALLBACK_AFTER + let fallback_reason = next_fallback_reason.unwrap_or("unknown"); + if fallback_reason == "strict_grace_fallback" { + let fallback_after = if ready_observed { + RUNTIME_FALLBACK_AFTER + } else { + STARTUP_FALLBACK_AFTER + }; + warn!( + target_mode = route_mode.as_str(), + cutover_generation = snapshot.generation, + grace_secs = fallback_after.as_secs(), + fallback_reason, + "ME pool stayed not-ready beyond grace; routing new sessions via Direct-DC" + ); } else { - STARTUP_FALLBACK_AFTER - }; - warn!( - target_mode = route_mode.as_str(), - cutover_generation = snapshot.generation, - grace_secs = fallback_after.as_secs(), - "ME pool stayed not-ready beyond grace; routing new sessions via Direct-DC" - ); + warn!( + target_mode = route_mode.as_str(), + cutover_generation = snapshot.generation, + fallback_reason, + "ME pool not-ready; routing new sessions via Direct-DC (fast mode)" + ); + } } } } @@ -108,7 +145,10 @@ pub(crate) async fn configure_admission_gate( admission_tx_gate.send_replace(gate_open); if gate_open { if next_fallback_active { - warn!("Conditional-admission gate opened in ME fallback mode"); + warn!( + fallback_reason = next_fallback_reason.unwrap_or("unknown"), + "Conditional-admission 
gate opened in ME fallback mode" + ); } else { info!("Conditional-admission gate opened / ME pool READY"); } diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index c376d0b..e3c3feb 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -9,8 +9,10 @@ use tracing::{debug, error, info, warn}; use crate::cli; use crate::config::ProxyConfig; use crate::logging::LogDestination; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ - ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, + ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, + save_proxy_config_cache, }; pub(crate) fn resolve_runtime_config_path( @@ -370,9 +372,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( cache_path: Option<&str>, me2dc_fallback: bool, label: &'static str, + upstream: Option>, ) -> Option { loop { - match fetch_proxy_config_with_raw(url).await { + match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await { Ok((cfg, raw)) => { if !cfg.map.is_empty() { if let Some(path) = cache_path diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 022f8ae..4e49e9e 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool( let proxy_secret_path = config.general.proxy_secret_path.as_deref(); let pool_size = config.general.middle_proxy_pool_size.max(1); let proxy_secret = loop { - match crate::transport::middle_proxy::fetch_proxy_secret( + match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream( proxy_secret_path, config.general.proxy_secret_len_max, + Some(upstream_manager.clone()), ) .await { @@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v4_cache_path.as_deref(), me2dc_fallback, "getProxyConfig", + Some(upstream_manager.clone()), ) .await; if cfg_v4.is_some() { @@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v6_cache_path.as_deref(), me2dc_fallback, "getProxyConfigV6", + Some(upstream_manager.clone()), ) .await; if cfg_v6.is_some() { @@ -274,6 +277,8 @@ pub(crate) async fn initialize_me_pool( config.general.me_warn_rate_limit_ms, config.general.me_route_no_writer_mode, config.general.me_route_no_writer_wait_ms, + config.general.me_route_hybrid_max_wait_ms, + config.general.me_route_blocking_send_timeout_ms, config.general.me_route_inline_recovery_attempts, config.general.me_route_inline_recovery_wait_ms, ); diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index ecd7f6d..921b8bd 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -168,15 +168,13 @@ async fn run_inner( ); std::process::exit(1); } - } else { - if let Err(e) = std::fs::create_dir_all(data_path) { - eprintln!( - "[telemt] Can't create data_path {}: {}", - data_path.display(), - e - ); - std::process::exit(1); - } + } else if let Err(e) = std::fs::create_dir_all(data_path) { + eprintln!( + "[telemt] Can't create data_path {}: {}", + data_path.display(), + e + ); + std::process::exit(1); } if let Err(e) = std::env::set_current_dir(data_path) { diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 342a2f9..7cf3039 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::config::ProxyConfig; use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker}; use crate::tls_front::TlsFrontCache; +use 
crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; pub(crate) async fn bootstrap_tls_front( @@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front( let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty()) .then(|| config.censorship.tls_fetch_scope.clone()); - let fetch_timeout = Duration::from_secs(5); + let tls_fetch = config.censorship.tls_fetch.clone(); + let fetch_strategy = TlsFetchStrategy { + profiles: tls_fetch.profiles, + strict_route: tls_fetch.strict_route, + attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)), + total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)), + grease_enabled: tls_fetch.grease_enabled, + deterministic: tls_fetch.deterministic, + profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs), + }; + let fetch_timeout = fetch_strategy.total_budget; let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); @@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); + let strategy_initial = fetch_strategy.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { @@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); + let strategy_domain = strategy_initial.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, @@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); + let strategy_refresh = fetch_strategy.clone(); tokio::spawn(async move { loop { let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); @@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); + let strategy_domain = strategy_refresh.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, diff --git a/src/main.rs b/src/main.rs index 26a10f9..0d29981 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,12 +11,12 @@ mod ip_tracker; mod logging; mod service; #[cfg(test)] -#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] -mod ip_tracker_hotpath_adversarial_tests; -#[cfg(test)] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] mod ip_tracker_encapsulation_adversarial_tests; #[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; mod maestro; @@ -32,6 +32,9 @@ mod transport; mod util; fn main() -> std::result::Result<(), Box> { + // Install rustls crypto provider early + let _ = 
rustls::crypto::ring::default_provider().install_default(); + let args: Vec = std::env::args().skip(1).collect(); let cmd = cli::parse_command(&args); @@ -40,20 +43,18 @@ fn main() -> std::result::Result<(), Box> { std::process::exit(exit_code); } - // On Unix, handle daemonization before starting tokio runtime #[cfg(unix)] { let daemon_opts = cmd.daemon_opts; - // Daemonize if requested (must happen before tokio runtime starts) + // Daemonize BEFORE runtime if daemon_opts.should_daemonize() { match daemon::daemonize(daemon_opts.working_dir.as_deref()) { Ok(daemon::DaemonizeResult::Parent) => { - // Parent process exits successfully std::process::exit(0); } Ok(daemon::DaemonizeResult::Child) => { - // Continue as daemon child + // continue } Err(e) => { eprintln!("[telemt] Daemonization failed: {}", e); @@ -62,7 +63,6 @@ fn main() -> std::result::Result<(), Box> { } } - // Now start tokio runtime and run the server tokio::runtime::Builder::new_multi_thread() .enable_all() .build()? diff --git a/src/metrics.rs b/src/metrics.rs index 2560294..2c87ed6 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -935,6 +935,459 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batches_total Total DC->Client flush batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batches_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batches_total {}", + if me_allows_normal { + stats.get_me_d2c_batches_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_frames_total Total DC->Client frames flushed in batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_batch_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_bytes_total Total DC->Client bytes flushed in batches" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_batch_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_flush_reason_total DC->Client flush reasons" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_flush_reason_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"queue_drain\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_queue_drain_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"batch_frames\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_batch_frames_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"batch_bytes\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_batch_bytes_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"max_delay\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_max_delay_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"ack_immediate\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_ack_immediate_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_reason_total{{reason=\"close\"}} {}", + if me_allows_normal { + stats.get_me_d2c_flush_reason_close_total() + } 
else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_data_frames_total DC->Client data frames" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_data_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_data_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_data_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_ack_frames_total DC->Client quick-ack frames" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_ack_frames_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_ack_frames_total {}", + if me_allows_normal { + stats.get_me_d2c_ack_frames_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_payload_bytes_total DC->Client payload bytes before transport framing" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_payload_bytes_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_payload_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_payload_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_write_mode_total DC->Client writer mode selection" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_write_mode_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_write_mode_total{{mode=\"coalesced\"}} {}", + if me_allows_normal { + stats.get_me_d2c_write_mode_coalesced_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_write_mode_total{{mode=\"split\"}} {}", + if me_allows_normal { + stats.get_me_d2c_write_mode_split_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_quota_reject_total DC->Client quota rejects" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_quota_reject_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_quota_reject_total{{stage=\"pre_write\"}} {}", + if me_allows_normal { + stats.get_me_d2c_quota_reject_pre_write_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_quota_reject_total{{stage=\"post_write\"}} {}", + if me_allows_normal { + stats.get_me_d2c_quota_reject_post_write_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_frame_buf_shrink_total DC->Client reusable frame buffer shrink events" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_frame_buf_shrink_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_frame_buf_shrink_total {}", + if me_allows_normal { + stats.get_me_d2c_frame_buf_shrink_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_frame_buf_shrink_bytes_total DC->Client reusable frame buffer bytes released" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_frame_buf_shrink_bytes_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_frame_buf_shrink_bytes_total {}", + if me_allows_normal { + stats.get_me_d2c_frame_buf_shrink_bytes_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_frames_bucket_total DC->Client batch frame count buckets" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_frames_bucket_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"1\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_1() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"2_4\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_2_4() + } else { + 0 + } + ); + let _ = writeln!( + out, + 
"telemt_me_d2c_batch_frames_bucket_total{{bucket=\"5_8\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_5_8() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"9_16\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_9_16() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"17_32\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_17_32() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_frames_bucket_total{{bucket=\"gt_32\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_frames_bucket_gt_32() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets" + ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_0_1k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"1k_4k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_1k_4k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"4k_16k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_4k_16k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"16k_64k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_16k_64k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"64k_128k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_64k_128k() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"gt_128k\"}} {}", + if me_allows_debug { + stats.get_me_d2c_batch_bytes_bucket_gt_128k() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_flush_duration_us_bucket_total DC->Client flush duration buckets" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"0_50\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_0_50() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"51_200\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_51_200() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"201_1000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_201_1000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"1001_5000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_1001_5000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"5001_20000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_5001_20000() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_d2c_flush_duration_us_bucket_total{{bucket=\"gt_20000\"}} {}", + if me_allows_debug { + stats.get_me_d2c_flush_duration_us_bucket_gt_20000() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP 
telemt_me_d2c_batch_timeout_armed_total DC->Client max-delay timer armed events" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_timeout_armed_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_timeout_armed_total {}", + if me_allows_debug { + stats.get_me_d2c_batch_timeout_armed_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_d2c_batch_timeout_fired_total DC->Client max-delay timer fired events" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_d2c_batch_timeout_fired_total counter" + ); + let _ = writeln!( + out, + "telemt_me_d2c_batch_timeout_fired_total {}", + if me_allows_debug { + stats.get_me_d2c_batch_timeout_fired_total() + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_me_writer_pick_total ME writer-pick outcomes by mode and result" @@ -1105,6 +1558,40 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp 0 } ); + let _ = writeln!( + out, + "# HELP telemt_me_endpoint_quarantine_unexpected_total ME endpoint quarantines caused by unexpected writer removals" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_endpoint_quarantine_unexpected_total counter" + ); + let _ = writeln!( + out, + "telemt_me_endpoint_quarantine_unexpected_total {}", + if me_allows_normal { + stats.get_me_endpoint_quarantine_unexpected_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "# HELP telemt_me_endpoint_quarantine_draining_suppressed_total Draining writer removals that skipped endpoint quarantine" + ); + let _ = writeln!( + out, + "# TYPE telemt_me_endpoint_quarantine_draining_suppressed_total counter" + ); + let _ = writeln!( + out, + "telemt_me_endpoint_quarantine_draining_suppressed_total {}", + if me_allows_normal { + stats.get_me_endpoint_quarantine_draining_suppressed_total() + } else { + 0 + } + ); let _ = writeln!( out, @@ -1865,6 +2352,20 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp 0 } ); + let _ = writeln!( + out, + "# HELP telemt_me_hybrid_timeout_total ME hybrid route timeouts after bounded retry window" + ); + let _ = writeln!(out, "# TYPE telemt_me_hybrid_timeout_total counter"); + let _ = writeln!( + out, + "telemt_me_hybrid_timeout_total {}", + if me_allows_normal { + stats.get_me_hybrid_timeout_total() + } else { + 0 + } + ); let _ = writeln!( out, "# HELP telemt_me_async_recovery_trigger_total Async ME recovery trigger attempts from route path" @@ -2145,6 +2646,19 @@ mod tests { stats.increment_relay_idle_hard_close_total(); stats.increment_relay_pressure_evict_total(); stats.increment_relay_protocol_desync_close_total(); + stats.increment_me_d2c_batches_total(); + stats.add_me_d2c_batch_frames_total(3); + stats.add_me_d2c_batch_bytes_total(2048); + stats.increment_me_d2c_flush_reason(crate::stats::MeD2cFlushReason::AckImmediate); + stats.increment_me_d2c_data_frames_total(); + stats.increment_me_d2c_ack_frames_total(); + stats.add_me_d2c_payload_bytes_total(1800); + stats.increment_me_d2c_write_mode(crate::stats::MeD2cWriteMode::Coalesced); + stats.increment_me_d2c_quota_reject_total(crate::stats::MeD2cQuotaRejectStage::PostWrite); + stats.observe_me_d2c_frame_buf_shrink(4096); + stats.increment_me_endpoint_quarantine_total(); + stats.increment_me_endpoint_quarantine_unexpected_total(); + stats.increment_me_endpoint_quarantine_draining_suppressed_total(); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); stats.add_user_octets_from("alice", 1024); @@ -2184,6 +2698,20 @@ mod tests { 
assert!(output.contains("telemt_relay_idle_hard_close_total 1")); assert!(output.contains("telemt_relay_pressure_evict_total 1")); assert!(output.contains("telemt_relay_protocol_desync_close_total 1")); + assert!(output.contains("telemt_me_d2c_batches_total 1")); + assert!(output.contains("telemt_me_d2c_batch_frames_total 3")); + assert!(output.contains("telemt_me_d2c_batch_bytes_total 2048")); + assert!(output.contains("telemt_me_d2c_flush_reason_total{reason=\"ack_immediate\"} 1")); + assert!(output.contains("telemt_me_d2c_data_frames_total 1")); + assert!(output.contains("telemt_me_d2c_ack_frames_total 1")); + assert!(output.contains("telemt_me_d2c_payload_bytes_total 1800")); + assert!(output.contains("telemt_me_d2c_write_mode_total{mode=\"coalesced\"} 1")); + assert!(output.contains("telemt_me_d2c_quota_reject_total{stage=\"post_write\"} 1")); + assert!(output.contains("telemt_me_d2c_frame_buf_shrink_total 1")); + assert!(output.contains("telemt_me_d2c_frame_buf_shrink_bytes_total 4096")); + assert!(output.contains("telemt_me_endpoint_quarantine_total 1")); + assert!(output.contains("telemt_me_endpoint_quarantine_unexpected_total 1")); + assert!(output.contains("telemt_me_endpoint_quarantine_draining_suppressed_total 1")); assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1")); assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1")); assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024")); @@ -2245,6 +2773,17 @@ mod tests { assert!(output.contains("# TYPE telemt_relay_idle_hard_close_total counter")); assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter")); assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_batches_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_flush_reason_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_write_mode_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_batch_frames_bucket_total counter")); + assert!(output.contains("# TYPE telemt_me_d2c_flush_duration_us_bucket_total counter")); + assert!(output.contains("# TYPE telemt_me_endpoint_quarantine_total counter")); + assert!(output.contains("# TYPE telemt_me_endpoint_quarantine_unexpected_total counter")); + assert!( + output + .contains("# TYPE telemt_me_endpoint_quarantine_draining_suppressed_total counter") + ); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); assert!( output diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 4b7f57e..8ce3e96 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -186,6 +186,72 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { } } +const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16; +#[cfg(test)] +const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5); + +fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration { + Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms) +} + +fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { + if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW { + return false; + } + + if initial_data.is_empty() { + // Empty initial_data means there is no client probe prefix to refine. + // Prefetching in this case can consume fallback relay payload bytes and + // accidentally route them through shaping heuristics. 
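+        // (The early return below therefore skips the prefetch read entirely.)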
+ return false; + } + + if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") { + return false; + } + + initial_data + .iter() + .all(|b| b.is_ascii_alphabetic() || *b == b' ') +} + +#[cfg(test)] +async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec) +where + R: AsyncRead + Unpin, +{ + extend_masking_initial_window_with_timeout( + reader, + initial_data, + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + ) + .await; +} + +async fn extend_masking_initial_window_with_timeout( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + fn masking_outcome( reader: R, writer: W, @@ -200,6 +266,15 @@ where W: AsyncWrite + Unpin + Send + 'static, { HandshakeOutcome::NeedsMasking(Box::pin(async move { + let mut reader = reader; + let mut initial_data = initial_data; + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + mask_classifier_prefetch_timeout(&config), + ) + .await; + handle_bad_client( reader, writer, @@ -242,13 +317,20 @@ fn record_handshake_failure_class( record_beobachten_class(beobachten, config, peer_ip, class); } +#[inline] +fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) { + if matches!(error, ProxyError::UnknownTlsSni) { + stats.increment_connects_bad(); + } +} + fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { if trusted.is_empty() { static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); if !warned.swap(true, Ordering::Relaxed) { warn!( - "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers" ); } return false; @@ -433,7 +515,10 @@ where beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -884,7 +969,10 @@ impl RunningClientHandler { self.beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -1153,7 +1241,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1212,7 +1300,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1321,6 +1409,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] mod 
masking_probe_evasion_blackhat_tests; +#[cfg(test)] +#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"] +mod masking_fragmented_classifier_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_replay_timing_security_tests.rs"] +mod masking_replay_timing_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"] +mod masking_http2_fragmented_preface_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"] +mod masking_prefetch_invariant_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"] +mod masking_prefetch_timing_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"] +mod masking_prefetch_config_runtime_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"] +mod masking_prefetch_config_pipeline_integration_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"] +mod masking_prefetch_strict_boundary_security_tests; + #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; @@ -1328,3 +1448,15 @@ mod beobachten_ttl_bounds_security_tests; #[cfg(test)] #[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] mod tls_record_wrap_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_clever_advanced_tests.rs"] +mod client_clever_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_more_advanced_tests.rs"] +mod client_more_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_deep_invariants_tests.rs"] +mod client_deep_invariants_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5632977..fbaffa2 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -13,10 +13,10 @@ use std::sync::Arc; use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, trace, warn}; +use tracing::{debug, info, trace, warn}; use zeroize::{Zeroize, Zeroizing}; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, UnknownSniAction}; use crate::crypto::{AesCtr, SecureRandom, sha256}; use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; @@ -28,6 +28,8 @@ use rand::RngExt; const ACCESS_SECRET_BYTES: usize = 16; static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); +const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5; +static UNKNOWN_SNI_WARN_NEXT_ALLOWED: OnceLock>> = OnceLock::new(); #[cfg(test)] const WARNED_SECRET_MAX_ENTRIES: usize = 64; #[cfg(not(test))] @@ -86,6 +88,24 @@ fn auth_probe_saturation_state_lock() .unwrap_or_else(|poisoned| poisoned.into_inner()) } +fn unknown_sni_warn_state_lock() -> std::sync::MutexGuard<'static, Option> { + UNKNOWN_SNI_WARN_NEXT_ALLOWED + .get_or_init(|| Mutex::new(None)) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn should_emit_unknown_sni_warn(now: Instant) -> bool { + let mut guard = unknown_sni_warn_state_lock(); + if let Some(next_allowed) = *guard + && now < next_allowed + { + return false; + } + *guard = Some(now + Duration::from_secs(UNKNOWN_SNI_WARN_COOLDOWN_SECS)); + true +} + fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { match peer_ip { IpAddr::V4(ip) => IpAddr::V4(ip), @@ -121,6 +141,19 @@ fn auth_probe_eviction_offset(peer_ip: 
IpAddr, now: Instant) -> usize { hasher.finish() as usize } +fn auth_probe_scan_start_offset( + peer_ip: IpAddr, + now: Instant, + state_len: usize, + scan_limit: usize, +) -> usize { + if state_len == 0 || scan_limit == 0 { + return 0; + } + + auth_probe_eviction_offset(peer_ip, now) % state_len +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -269,34 +302,9 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = if state_len == 0 { - 0 - } else { - auth_probe_eviction_offset(peer_ip, now) % state_len - }; - let mut scanned = 0usize; - for entry in state.iter().skip(start_offset) { - let key = *entry.key(); - let fail_streak = entry.value().fail_streak; - let last_seen = entry.value().last_seen; - match eviction_candidate { - Some((_, current_fail, current_seen)) - if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => {} - _ => eviction_candidate = Some((key, fail_streak, last_seen)), - } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); - } - scanned += 1; - if scanned >= scan_limit { - break; - } - } - - if scanned < scan_limit { - for entry in state.iter().take(scan_limit - scanned) { + if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT { + for entry in state.iter() { let key = *entry.key(); let fail_streak = entry.value().fail_streak; let last_seen = entry.value().last_seen; @@ -310,6 +318,46 @@ fn auth_probe_record_failure_with_state( stale_keys.push(key); } } + } else { + let start_offset = + auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail + && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } } for stale_key in stale_keys { @@ -384,6 +432,25 @@ fn auth_probe_test_lock() -> &'static Mutex<()> { TEST_LOCK.get_or_init(|| Mutex::new(())) } +#[cfg(test)] +fn unknown_sni_warn_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_unknown_sni_warn_state_for_testing() { + if UNKNOWN_SNI_WARN_NEXT_ALLOWED.get().is_some() { + let mut guard = unknown_sni_warn_state_lock(); + *guard = None; + } +} + +#[cfg(test)] +fn should_emit_unknown_sni_warn_for_testing(now: Instant) -> bool { + 
should_emit_unknown_sni_warn(now) +} + #[cfg(test)] fn clear_warned_secrets_for_testing() { if let Some(warned) = INVALID_SECRET_WARNED.get() @@ -501,6 +568,21 @@ fn decode_user_secrets( secrets } +#[inline] +fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { + if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { + return Some(config.censorship.tls_domain.as_str()); + } + + for domain in &config.censorship.tls_domains { + if domain.eq_ignore_ascii_case(sni) { + return Some(domain.as_str()); + } + } + + None +} + async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { if config.censorship.server_hello_delay_max_ms == 0 { return; @@ -584,70 +666,12 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); - let secrets = decode_user_secrets(config, client_sni.as_deref()); - - let validation = match tls::validate_tls_handshake_with_replay_window( - handshake, - &secrets, - config.access.ignore_time_skew, - config.access.replay_window_secs, - ) { - Some(v) => v, - None => { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - debug!( - peer = %peer, - ignore_time_skew = config.access.ignore_time_skew, - "TLS handshake validation failed - no matching user or time skew" - ); - return HandshakeResult::BadClient { reader, writer }; - } - }; - - // Replay tracking is applied only after successful authentication to avoid - // letting unauthenticated probes evict valid entries from the replay cache. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { - Some((_, s)) => s, - None => { - maybe_apply_server_hello_delay(config).await; - return HandshakeResult::BadClient { reader, writer }; - } - }; - - let cached = if config.censorship.tls_emulation { - if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = client_sni.as_ref() { - if cache.contains_domain(sni).await { - sni.clone() - } else { - config.censorship.tls_domain.clone() - } - } else { - config.censorship.tls_domain.clone() - }; - let cached_entry = cache.get(&selected_domain).await; - let use_full_cert_payload = cache - .take_full_cert_budget_for_ip( - peer.ip(), - Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), - ) - .await; - Some((cached_entry, use_full_cert_payload)) - } else { - None - } - } else { - None - }; + let preferred_user_hint = client_sni + .as_deref() + .filter(|sni| config.access.users.contains_key(*sni)); + let matched_tls_domain = client_sni + .as_deref() + .and_then(|sni| find_matching_tls_domain(config, sni)); let alpn_list = if config.censorship.alpn_enforce { tls::extract_alpn_from_client_hello(handshake) @@ -670,6 +694,94 @@ where None }; + if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + let sni = client_sni.as_deref().unwrap_or_default(); + let log_now = Instant::now(); + if should_emit_unknown_sni_warn(log_now) { + warn!( + peer = %peer, + sni = %sni, + unknown_sni = true, + unknown_sni_action = 
?config.censorship.unknown_sni_action, + "TLS handshake rejected by unknown SNI policy" + ); + } else { + info!( + peer = %peer, + sni = %sni, + unknown_sni = true, + unknown_sni_action = ?config.censorship.unknown_sni_action, + "TLS handshake rejected by unknown SNI policy" + ); + } + return match config.censorship.unknown_sni_action { + UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni), + UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer }, + }; + } + + let secrets = decode_user_secrets(config, preferred_user_hint); + + let validation = match tls::validate_tls_handshake_with_replay_window( + handshake, + &secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ) { + Some(v) => v, + None => { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + ignore_time_skew = config.access.ignore_time_skew, + "TLS handshake validation failed - no matching user or time skew" + ); + return HandshakeResult::BadClient { reader, writer }; + } + }; + + // Reject known replay digests before expensive cache/domain/ALPN policy work. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { + Some((_, s)) => s, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } + }; + + let cached = if config.censorship.tls_emulation { + if let Some(cache) = tls_cache.as_ref() { + let selected_domain = + matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); + let cached_entry = cache.get(selected_domain).await; + let use_full_cert_payload = cache + .take_full_cert_budget_for_ip( + peer.ip(), + Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), + ) + .await; + Some((cached_entry, use_full_cert_payload)) + } else { + None + } + } else { + None + }; + + // Add replay digest only for policy-valid handshakes. 
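+    // Split from the earlier check_tls_digest() probe so rejected handshakes never
+    // enter the cache and cannot evict digests recorded for legitimate clients.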
+ replay_checker.add_tls_digest(digest_half); + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, @@ -769,7 +881,7 @@ where let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); - let dec_key = sha256(&dec_key_input); + let dec_key = Zeroizing::new(sha256(&dec_key_input)); let mut dec_iv_arr = [0u8; IV_LEN]; dec_iv_arr.copy_from_slice(dec_iv_bytes); @@ -805,7 +917,7 @@ where let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); - let enc_key = sha256(&enc_key_input); + let enc_key = Zeroizing::new(sha256(&enc_key_input)); let mut enc_iv_arr = [0u8; IV_LEN]; enc_iv_arr.copy_from_slice(enc_iv_bytes); @@ -830,9 +942,9 @@ where user: user.clone(), dc_idx, proto_tag, - dec_key, + dec_key: *dec_key, dec_iv, - enc_key, + enc_key: *enc_key, enc_iv, peer, is_tls, @@ -979,6 +1091,38 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"] +mod auth_probe_scan_budget_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] +mod auth_probe_scan_offset_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"] +mod auth_probe_eviction_bias_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_advanced_clever_tests.rs"] +mod advanced_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_more_clever_tests.rs"] +mod more_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_real_bug_stress_tests.rs"] +mod real_bug_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_timing_manual_bench_tests.rs"] +mod timing_manual_bench_tests; + +#[cfg(test)] +#[path = "tests/handshake_key_material_zeroization_security_tests.rs"] +mod handshake_key_material_zeroization_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. 
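For readers skimming the handshake changes above: the unknown-SNI branch rate-limits its warn-level log through a process-wide cooldown timestamp, downgrading to info while the window is closed. A minimal self-contained sketch of that pattern follows; `COOLDOWN`, `NEXT_ALLOWED`, and `should_emit` are illustrative names, not the crate's API:

```rust
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant};

const COOLDOWN: Duration = Duration::from_secs(5);
static NEXT_ALLOWED: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();

/// Returns true at most once per COOLDOWN window; callers downgrade
/// their log level (warn -> info) when this returns false.
fn should_emit(now: Instant) -> bool {
    let mut guard = NEXT_ALLOWED
        .get_or_init(|| Mutex::new(None))
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    if let Some(next_allowed) = *guard {
        if now < next_allowed {
            return false;
        }
    }
    *guard = Some(now + COOLDOWN);
    true
}

fn main() {
    let t0 = Instant::now();
    assert!(should_emit(t0));            // first event logs at warn level
    assert!(!should_emit(t0));           // immediate repeat is suppressed
    assert!(should_emit(t0 + COOLDOWN)); // window elapsed, warn again
}
```

Recovering from a poisoned mutex via `into_inner()` matches the diff's own convention: the cooldown state is just a timestamp, so a panic elsewhere never needs to invalidate it.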
diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 3639db1..ba9f20a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -4,14 +4,23 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; -use rand::{Rng, RngExt}; -use std::net::SocketAddr; +#[cfg(unix)] +use nix::ifaddrs::getifaddrs; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use std::net::{IpAddr, SocketAddr}; use std::str; -use std::time::Duration; +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +#[cfg(unix)] +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, timeout}; use tracing::debug; @@ -30,28 +39,55 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +#[cfg(unix)] +#[cfg(not(test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); +#[cfg(all(unix, test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1); struct CopyOutcome { total: usize, ended_by_eof: bool, } -async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome +async fn copy_with_idle_timeout( + reader: &mut R, + writer: &mut W, + byte_cap: usize, + shutdown_on_eof: bool, +) -> CopyOutcome where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; + + if byte_cap == 0 { + return CopyOutcome { + total, + ended_by_eof, + }; + } + loop { - let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; if n == 0 { ended_by_eof = true; + if shutdown_on_eof { + let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await; + } break; } total = total.saturating_add(n); @@ -68,6 +104,31 @@ where } } +fn is_http_probe(data: &[u8]) -> bool { + // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". 
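+    // The prefix branch below also classifies 2-3 byte fragments of a method
+    // ("GE", "POS") as HTTP, so fragmented probes cannot dodge detection.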
+ const HTTP_METHODS: [&[u8]; 10] = [ + b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH", + b"PRI ", + ]; + + if data.is_empty() { + return false; + } + + let window = &data[..data.len().min(16)]; + for method in HTTP_METHODS { + if data.len() >= method.len() && window.starts_with(method) { + return true; + } + + if (2..=3).contains(&window.len()) && method.starts_with(window) { + return true; + } + } + + false +} + fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { if total == 0 || floor == 0 || cap < floor { return total; @@ -125,6 +186,11 @@ async fn maybe_write_shape_padding( let mut remaining = target_total - total_sent; let mut pad_chunk = [0u8; 1024]; let deadline = Instant::now() + MASK_TIMEOUT; + // Use a Send RNG so relay futures remain spawn-safe under Tokio. + let mut rng = { + let mut seed_source = rand::rng(); + StdRng::from_rng(&mut seed_source) + }; while remaining > 0 { let now = Instant::now(); @@ -133,10 +199,7 @@ async fn maybe_write_shape_padding( } let write_len = remaining.min(pad_chunk.len()); - { - let mut rng = rand::rng(); - rng.fill_bytes(&mut pad_chunk[..write_len]); - } + rng.fill_bytes(&mut pad_chunk[..write_len]); let write_budget = deadline.saturating_duration_since(now); match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { Ok(Ok(())) => {} @@ -167,11 +230,11 @@ where } } -async fn consume_client_data_with_timeout(reader: R) +async fn consume_client_data_with_timeout_and_cap(reader: R, byte_cap: usize) where R: AsyncRead + Unpin, { - if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap)) .await .is_err() { @@ -190,6 +253,13 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { if config.censorship.mask_timing_normalization_enabled { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; + if floor == 0 { + if ceiling == 0 { + return Duration::from_millis(0); + } + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(0..=ceiling)); + } if ceiling > floor { let mut rng = rand::rng(); return Duration::from_millis(rng.random_range(floor..=ceiling)); @@ -219,14 +289,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request - if data.len() > 4 - && (data.starts_with(b"GET ") - || data.starts_with(b"POST") - || data.starts_with(b"HEAD") - || data.starts_with(b"PUT ") - || data.starts_with(b"DELETE") - || data.starts_with(b"OPTIONS")) - { + if is_http_probe(data) { return "HTTP"; } @@ -248,6 +311,247 @@ fn detect_client_type(data: &[u8]) -> &'static str { "unknown" } +fn parse_mask_host_ip_literal(host: &str) -> Option { + if host.starts_with('[') && host.ends_with(']') { + return host[1..host.len() - 1].parse::().ok(); + } + host.parse::().ok() +} + +fn canonical_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .map(IpAddr::V4) + .unwrap_or(IpAddr::V6(v6)), + IpAddr::V4(v4) => IpAddr::V4(v4), + } +} + +#[cfg(unix)] +fn collect_local_interface_ips() -> Vec { + #[cfg(test)] + LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed); + + let mut out = Vec::new(); + if let Ok(addrs) = getifaddrs() { + for iface in addrs { + if let Some(address) = iface.address { + if let Some(v4) = 
address.as_sockaddr_in() {
+                    out.push(canonical_ip(IpAddr::V4(v4.ip())));
+                } else if let Some(v6) = address.as_sockaddr_in6() {
+                    out.push(canonical_ip(IpAddr::V6(v6.ip())));
+                }
+            }
+        }
+    }
+    out
+}
+
+fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
+    if refreshed.is_empty() && !previous.is_empty() {
+        return previous.to_vec();
+    }
+
+    refreshed
+}
+
+#[cfg(unix)]
+#[derive(Default)]
+struct LocalInterfaceCache {
+    ips: Vec<IpAddr>,
+    refreshed_at: Option<StdInstant>,
+}
+
+#[cfg(unix)]
+static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
+
+#[cfg(unix)]
+static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
+
+#[cfg(all(unix, test))]
+fn local_interface_ips() -> Vec<IpAddr> {
+    let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
+    let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+
+    let stale = guard
+        .refreshed_at
+        .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+    if stale {
+        let refreshed = collect_local_interface_ips();
+        guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
+        guard.refreshed_at = Some(StdInstant::now());
+    }
+
+    guard.ips.clone()
+}
+
+#[cfg(unix)]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
+    let _refresh_guard = refresh_lock.lock().await;
+
+    {
+        let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        let stale = guard
+            .refreshed_at
+            .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+        if !stale {
+            return guard.ips.clone();
+        }
+    }
+
+    let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
+        .await
+        .unwrap_or_default();
+
+    let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+    let stale = guard
+        .refreshed_at
+        .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
+    if stale {
+        guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
+        guard.refreshed_at = Some(StdInstant::now());
+    }
+
+    guard.ips.clone()
+}
+
+#[cfg(all(not(unix), test))]
+fn local_interface_ips() -> Vec<IpAddr> {
+    Vec::new()
+}
+
+#[cfg(not(unix))]
+async fn local_interface_ips_async() -> Vec<IpAddr> {
+    Vec::new()
+}
+
+#[cfg(test)]
+static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
+
+#[cfg(test)]
+fn reset_local_interface_enumerations_for_tests() {
+    LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
+
+    #[cfg(unix)]
+    if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
+        let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
+        guard.ips.clear();
+        guard.refreshed_at = None;
+    }
+}
+
+#[cfg(test)]
+fn local_interface_enumerations_for_tests() -> usize {
+    LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
+}
+
+fn is_mask_target_local_listener_with_interfaces(
+    mask_host: &str,
+    mask_port: u16,
+    local_addr: SocketAddr,
+    resolved_override: Option<SocketAddr>,
+    interface_ips: &[IpAddr],
+) -> bool {
+    if mask_port != local_addr.port() {
+        return false;
+    }
+
+    let local_ip = canonical_ip(local_addr.ip());
+    let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
+
+    if let Some(addr) = resolved_override {
+        let resolved_ip =
canonical_ip(addr.ip()); + if resolved_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (resolved_ip.is_loopback() + || resolved_ip.is_unspecified() + || interface_ips.contains(&resolved_ip)) + { + return true; + } + } + + if let Some(mask_ip) = literal_mask_ip { + if mask_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (mask_ip.is_loopback() + || mask_ip.is_unspecified() + || interface_ips.contains(&mask_ip)) + { + return true; + } + } + + false +} + +#[cfg(test)] +fn is_mask_target_local_listener( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips(); + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +async fn is_mask_target_local_listener_async( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips_async().await; + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { + let minutes = config.general.beobachten_minutes; + let clamped = minutes.clamp(1, 24 * 60); + Duration::from_secs(clamped.saturating_mul(60)) +} + fn build_mask_proxy_header( version: u8, peer: SocketAddr, @@ -290,13 +594,14 @@ pub async fn handle_bad_client( { let client_type = detect_client_type(initial_data); if config.general.beobachten { - let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)); + let ttl = masking_beobachten_ttl(config); beobachten.record(client_type, peer.ip(), ttl); } if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; return; } @@ -341,6 +646,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -353,12 +659,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -372,6 +686,29 @@ pub async fn handle_bad_client( .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; + // Fail closed when fallback points at our own listener endpoint. + // Self-referential masking can create recursive proxy loops under + // misconfiguration and leak distinguishable load spikes to adversaries. 
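+    // The guard below compares both the DNS-override resolution and the literal
+    // mask host against the listener address and enumerated local interface IPs.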
+ let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); + if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr) + .await + { + let outcome_started = Instant::now(); + debug!( + client_type = client_type, + host = %mask_host, + port = mask_port, + local = %local_addr, + "Mask target resolves to local listener; refusing self-referential masking fallback" + ); + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + } + + let outcome_started = Instant::now(); + debug!( client_type = client_type, host = %mask_host, @@ -381,10 +718,9 @@ pub async fn handle_bad_client( ); // Apply runtime DNS override for mask target when configured. - let mask_addr = resolve_socket_addr(mask_host, mask_port) + let mask_addr = resolved_mask_addr .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); - let outcome_started = Instant::now(); let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { @@ -413,6 +749,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -425,12 +762,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -449,6 +794,7 @@ async fn relay_to_mask( shape_above_cap_blur: bool, shape_above_cap_blur_max_bytes: usize, shape_hardening_aggressive_mode: bool, + mask_relay_max_bytes: usize, ) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -464,8 +810,18 @@ async fn relay_to_mask( } let (upstream_copy, downstream_copy) = tokio::join!( - async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, - async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } + async { + copy_with_idle_timeout( + &mut reader, + &mut mask_write, + mask_relay_max_bytes, + !shape_hardening_enabled, + ) + .await + }, + async { + copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await + } ); let total_sent = initial_data.len().saturating_add(upstream_copy.total); @@ -491,13 +847,36 @@ async fn relay_to_mask( let _ = writer.shutdown().await; } -/// Just consume all data from client without responding -async fn consume_client_data(mut reader: R) { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - while let Ok(n) = reader.read(&mut buf).await { +/// Just consume all data from client without responding. +async fn consume_client_data(mut reader: R, byte_cap: usize) { + if byte_cap == 0 { + return; + } + + // Keep drain path fail-closed under slow-loris stalls. 
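+    // Every read below is bounded by MASK_RELAY_IDLE_TIMEOUT and by the remaining
+    // byte_cap budget, so a stalled or chatty client cannot pin this task forever.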
+ let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut total = 0usize; + + loop { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { break; } + + total = total.saturating_add(n); + if total >= byte_cap { + break; + } } } @@ -521,6 +900,10 @@ mod masking_shape_above_cap_blur_security_tests; #[path = "tests/masking_timing_normalization_security_tests.rs"] mod masking_timing_normalization_security_tests; +#[cfg(test)] +#[path = "tests/masking_timing_budget_coupling_security_tests.rs"] +mod masking_timing_budget_coupling_security_tests; + #[cfg(test)] #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] mod masking_ab_envelope_blur_integration_security_tests; @@ -548,3 +931,75 @@ mod masking_aggressive_mode_security_tests; #[cfg(test)] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] mod masking_timing_sidechannel_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/masking_self_target_loop_security_tests.rs"] +mod masking_self_target_loop_security_tests; + +#[cfg(test)] +#[path = "tests/masking_classification_completeness_security_tests.rs"] +mod masking_classification_completeness_security_tests; + +#[cfg(test)] +#[path = "tests/masking_relay_guardrails_security_tests.rs"] +mod masking_relay_guardrails_security_tests; + +#[cfg(test)] +#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"] +mod masking_connect_failure_close_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/masking_additional_hardening_security_tests.rs"] +mod masking_additional_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_idle_timeout_security_tests.rs"] +mod masking_consume_idle_timeout_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_probe_classification_security_tests.rs"] +mod masking_http2_probe_classification_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http_probe_boundary_security_tests.rs"] +mod masking_http_probe_boundary_security_tests; + +#[cfg(test)] +#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"] +mod masking_rng_hoist_perf_regression_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_preface_integration_security_tests.rs"] +mod masking_http2_preface_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_stress_adversarial_tests.rs"] +mod masking_consume_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_security_tests.rs"] +mod masking_interface_cache_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] +mod masking_interface_cache_defense_in_depth_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"] +mod masking_interface_cache_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/masking_production_cap_regression_security_tests.rs"] +mod masking_production_cap_regression_security_tests; + +#[cfg(test)] +#[path = "tests/masking_extended_attack_surface_security_tests.rs"] +mod masking_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/masking_padding_timeout_adversarial_tests.rs"] +mod masking_padding_timeout_adversarial_tests; + +#[cfg(all(test, feature = "redteam_offline_expected_fail"))] +#[path = 
"tests/masking_offline_target_redteam_expected_fail_tests.rs"] +mod masking_offline_target_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index f56a606..d2f37a6 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,14 +1,16 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; +#[cfg(test)] +use std::future::Future; use std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -21,7 +23,9 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::Stats; +use crate::stats::{ + MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, +}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -32,36 +36,37 @@ enum C2MeCommand { const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; -const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024; const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000); const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +const TINY_FRAME_DEBT_PER_TINY: u32 = 8; +const TINY_FRAME_DEBT_LIMIT: u32 = 512; #[cfg(test)] -const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); -#[cfg(not(test))] -const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); +const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; +const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; +const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; +const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); +static DESYNC_DEDUP_PREVIOUS: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); -static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); +static DESYNC_DEDUP_ROTATION_STATE: OnceLock> = OnceLock::new(); +// Invariant for async callers: +// this std::sync::Mutex is allowed only because critical sections are short, +// synchronous, and MUST never cross an `.await`. 
 static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
 static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
 
+#[derive(Default)]
+struct DesyncDedupRotationState {
+    current_started_at: Option<Instant>,
+}
+
 struct RelayForensicsState {
     trace_id: u64,
     conn_id: u64,
@@ -92,10 +97,25 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
     RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
 }
 
+fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry>
+{
+    // Keep lock scope narrow and synchronous: callers must drop guard before any `.await`.
+    let registry = relay_idle_candidate_registry();
+    match registry.lock() {
+        Ok(guard) => guard,
+        Err(poisoned) => {
+            let mut guard = poisoned.into_inner();
+            // Fail closed after panic while holding registry lock: drop all
+            // candidates and pressure cursors to avoid stale cross-session state.
+            *guard = RelayIdleCandidateRegistry::default();
+            registry.clear_poison();
+            guard
+        }
+    }
+}
+
 fn mark_relay_idle_candidate(conn_id: u64) -> bool {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return false;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();
 
     if guard.by_conn_id.contains_key(&conn_id) {
         return false;
@@ -114,9 +134,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
 }
 
 fn clear_relay_idle_candidate(conn_id: u64) {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();
 
     if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
         guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@@ -125,23 +143,17 @@
 #[cfg(test)]
 fn oldest_relay_idle_candidate() -> Option<u64> {
-    let Ok(guard) = relay_idle_candidate_registry().lock() else {
-        return None;
-    };
+    let guard = relay_idle_candidate_registry_lock();
     guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
 }
 
 fn note_relay_pressure_event() {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();
     guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
 }
 
 fn relay_pressure_event_seq() -> u64 {
-    let Ok(guard) = relay_idle_candidate_registry().lock() else {
-        return 0;
-    };
+    let guard = relay_idle_candidate_registry_lock();
     guard.pressure_event_seq
 }
 
@@ -150,9 +162,7 @@ fn maybe_evict_idle_candidate_on_pressure(
     seen_pressure_seq: &mut u64,
     stats: &Stats,
 ) -> bool {
-    let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
-        return false;
-    };
+    let mut guard = relay_idle_candidate_registry_lock();
 
     let latest_pressure_seq = guard.pressure_event_seq;
     if latest_pressure_seq == *seen_pressure_seq {
@@ -197,13 +207,9 @@
 #[cfg(test)]
 fn clear_relay_idle_pressure_state_for_testing() {
-    if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
-        && let Ok(mut guard) = registry.lock()
-    {
-        guard.by_conn_id.clear();
-        guard.ordered.clear();
-        guard.pressure_event_seq = 0;
-        guard.pressure_consumed_seq = 0;
+    if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
+        let mut guard = relay_idle_candidate_registry_lock();
+        *guard = RelayIdleCandidateRegistry::default();
     }
     RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
 }
@@ -214,6 +220,8 @@ struct MeD2cFlushPolicy {
     max_bytes: usize,
     max_delay: Duration,
     ack_flush_immediate: bool,
+    quota_soft_overshoot_bytes: u64,
+    frame_buf_shrink_threshold_bytes: usize,
 }
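To make the batching knobs above concrete, here is a schematic of how a policy like this typically drives flush decisions. It is not the crate's actual writer loop; `Policy`, `FlushReason`, and `flush_due` are illustrative names chosen to mirror the struct fields and a subset of the `telemt_me_d2c_flush_reason_total` metric labels (the real code also tracks `queue_drain`, `ack_immediate`, and `close`):

```rust
use std::time::{Duration, Instant};

#[derive(Debug, PartialEq)]
enum FlushReason {
    BatchFrames, // frame count hit the batch limit
    BatchBytes,  // buffered bytes hit the batch limit
    MaxDelay,    // oldest pending frame waited past the delay budget
}

struct Policy {
    max_frames: usize,
    max_bytes: usize,
    max_delay: Duration,
}

/// Decide whether a pending batch must be flushed now, and why.
fn flush_due(
    p: &Policy,
    frames: usize,
    bytes: usize,
    oldest_pending: Instant,
    now: Instant,
) -> Option<FlushReason> {
    if frames >= p.max_frames {
        return Some(FlushReason::BatchFrames);
    }
    if bytes >= p.max_bytes {
        return Some(FlushReason::BatchBytes);
    }
    if now.duration_since(oldest_pending) >= p.max_delay {
        return Some(FlushReason::MaxDelay);
    }
    None
}

fn main() {
    let p = Policy {
        max_frames: 16,
        max_bytes: 64 * 1024,
        max_delay: Duration::from_micros(500),
    };
    let t0 = Instant::now();
    assert_eq!(flush_due(&p, 3, 70_000, t0, t0), Some(FlushReason::BatchBytes));
    assert_eq!(flush_due(&p, 3, 1_000, t0, t0), None); // keep batching
}
```

The minimum clamps applied in `MeD2cFlushPolicy::from_config` (e.g. `ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN`) exist so that a misconfigured zero threshold cannot degenerate into flushing on every frame.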
#[derive(Clone, Copy)] @@ -255,6 +263,7 @@ impl RelayClientIdlePolicy { struct RelayClientIdleState { last_client_frame_at: Instant, soft_idle_marked: bool, + tiny_frame_debt: u32, } impl RelayClientIdleState { @@ -262,6 +271,7 @@ impl RelayClientIdleState { Self { last_client_frame_at: now, soft_idle_marked: false, + tiny_frame_debt: 0, } } @@ -284,6 +294,11 @@ impl MeD2cFlushPolicy { .max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, + quota_soft_overshoot_bytes: config.general.me_quota_soft_overshoot_bytes, + frame_buf_shrink_threshold_bytes: config + .general + .me_d2c_frame_buf_shrink_threshold_bytes + .max(4096), } } } @@ -302,64 +317,76 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { return true; } - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; - let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false)); - if saturated_before { - ever_saturated.store(true, Ordering::Relaxed); - } + let dedup_current = DESYNC_DEDUP.get_or_init(DashMap::new); + let dedup_previous = DESYNC_DEDUP_PREVIOUS.get_or_init(DashMap::new); + let rotation_state = + DESYNC_DEDUP_ROTATION_STATE.get_or_init(|| Mutex::new(DesyncDedupRotationState::default())); - if let Some(mut seen_at) = dedup.get_mut(&key) { - if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { - *seen_at = now; - return true; + let mut state = match rotation_state.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = DesyncDedupRotationState::default(); + rotation_state.clear_poison(); + guard } - return false; - } - - if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { - let mut stale_keys = Vec::new(); - let mut oldest_candidate: Option<(u64, Instant)> = None; - for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { - let key = *entry.key(); - let seen_at = *entry.value(); - - match oldest_candidate { - Some((_, oldest_seen)) if seen_at >= oldest_seen => {} - _ => oldest_candidate = Some((key, seen_at)), - } - - if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW { - stale_keys.push(*entry.key()); - } - } - for stale_key in stale_keys { - dedup.remove(&stale_key); - } - if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { - let Some((evict_key, _)) = oldest_candidate else { - return false; - }; - dedup.remove(&evict_key); - dedup.insert(key, now); - return should_emit_full_desync_full_cache(now); - } - } - - dedup.insert(key, now); - let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; - // Preserve the first sequential insert that reaches capacity as a normal - // emit, while still gating concurrent newcomer churn after the cache has - // ever been observed at saturation. 
- let was_ever_saturated = if saturated_after {
- ever_saturated.swap(true, Ordering::Relaxed)
- } else {
- ever_saturated.load(Ordering::Relaxed)
 };
- if saturated_before || (saturated_after && was_ever_saturated) {
+    let rotate_now = match state.current_started_at {
+        Some(current_started_at) => match now.checked_duration_since(current_started_at) {
+            Some(elapsed) => elapsed >= DESYNC_DEDUP_WINDOW,
+            None => true,
+        },
+        None => true,
+    };
+    if rotate_now {
+        dedup_previous.clear();
+        for entry in dedup_current.iter() {
+            dedup_previous.insert(*entry.key(), *entry.value());
+        }
+        dedup_current.clear();
+        state.current_started_at = Some(now);
+    }
+
+    if let Some(seen_at) = dedup_current.get(&key).map(|entry| *entry.value()) {
+        let within_window = match now.checked_duration_since(seen_at) {
+            Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW,
+            None => true,
+        };
+        if within_window {
+            return false;
+        }
+        dedup_current.insert(key, now);
+        return true;
+    }
+
+    if let Some(seen_at) = dedup_previous.get(&key).map(|entry| *entry.value()) {
+        let within_window = match now.checked_duration_since(seen_at) {
+            Some(elapsed) => elapsed < DESYNC_DEDUP_WINDOW,
+            None => true,
+        };
+        if within_window {
+            // Keep the original timestamp when promoting from previous bucket,
+            // so dedup expiry remains tied to first-seen time.
+            dedup_current.insert(key, seen_at);
+            return false;
+        }
+        dedup_previous.remove(&key);
+    }
+
+    if dedup_current.len() >= DESYNC_DEDUP_MAX_ENTRIES {
+        // Bounded eviction path: rotate buckets instead of scanning/evicting
+        // arbitrary entries from a saturated single map.
+        dedup_previous.clear();
+        for entry in dedup_current.iter() {
+            dedup_previous.insert(*entry.key(), *entry.value());
+        }
+        dedup_current.clear();
+        state.current_started_at = Some(now);
+        dedup_current.insert(key, now);
 should_emit_full_desync_full_cache(now)
 } else {
+        dedup_current.insert(key, now);
 true
 }
 }
@@ -395,8 +422,20 @@ fn clear_desync_dedup_for_testing() {
 if let Some(dedup) = DESYNC_DEDUP.get() {
 dedup.clear();
 }
- if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
- ever_saturated.store(false, Ordering::Relaxed);
+    if let Some(dedup_previous) = DESYNC_DEDUP_PREVIOUS.get() {
+        dedup_previous.clear();
+    }
+    if let Some(rotation_state) = DESYNC_DEDUP_ROTATION_STATE.get() {
+        match rotation_state.lock() {
+            Ok(mut guard) => {
+                *guard = DesyncDedupRotationState::default();
+            }
+            Err(poisoned) => {
+                let mut guard = poisoned.into_inner();
+                *guard = DesyncDedupRotationState::default();
+                rotation_state.clear_poison();
+            }
+        }
 }
 if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
 match last_emit_at.lock() {
@@ -522,73 +561,90 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
 has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
 }
-fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
-    quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
+fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
+    limit.saturating_add(overshoot)
 }
-fn quota_would_be_exceeded_for_user(
-    stats: &Stats,
-    user: &str,
-    quota_limit: Option<u64>,
+async fn reserve_user_quota_with_yield(
+    user_stats: &UserStats,
 bytes: u64,
-) -> bool {
-    quota_limit.is_some_and(|quota| {
-        let used = stats.get_user_total_octets(user);
-        used >= quota || bytes > quota.saturating_sub(used)
-    })
+    limit: u64,
+) -> std::result::Result<u64, QuotaReserveError> {
+    loop {
+        for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
+            match user_stats.quota_try_reserve(bytes, limit) {
+                Ok(total) => return Ok(total),
+                Err(QuotaReserveError::LimitExceeded) => {
+                    return Err(QuotaReserveError::LimitExceeded);
+                }
+                Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
+            }
+        }
+
+        tokio::task::yield_now().await;
+    }
+}
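// Reviewer sketch (not part of the diff): the retry structure of
// `reserve_user_quota_with_yield` restated against a bare AtomicU64, since
// `UserStats::quota_try_reserve` is project-internal. A CAS either reserves the
// bytes atomically or reports contention; on contention we spin briefly, then
// yield to the runtime so a starved task cannot monopolize a worker thread.
// `try_reserve` and `ReserveError` are hypothetical names for illustration.
use std::sync::atomic::{AtomicU64, Ordering};

enum ReserveError {
    LimitExceeded,
    Contended,
}

fn try_reserve(used: &AtomicU64, bytes: u64, limit: u64) -> Result<u64, ReserveError> {
    let current = used.load(Ordering::Relaxed);
    let next = current.saturating_add(bytes);
    if next > limit {
        return Err(ReserveError::LimitExceeded); // never charge past the cap
    }
    used.compare_exchange(current, next, Ordering::AcqRel, Ordering::Relaxed)
        .map(|_| next)
        .map_err(|_| ReserveError::Contended) // another task won the race
}

async fn reserve_with_yield(used: &AtomicU64, bytes: u64, limit: u64) -> Result<u64, ReserveError> {
    loop {
        for _ in 0..32 {
            match try_reserve(used, bytes, limit) {
                Err(ReserveError::Contended) => std::hint::spin_loop(),
                other => return other,
            }
        }
        tokio::task::yield_now().await; // let other tasks make progress
    }
}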
+fn classify_me_d2c_flush_reason(
+    flush_immediately: bool,
+    batch_frames: usize,
+    max_frames: usize,
+    batch_bytes: usize,
+    max_bytes: usize,
+    max_delay_fired: bool,
+) -> MeD2cFlushReason {
+    if flush_immediately {
+        return MeD2cFlushReason::AckImmediate;
+    }
+    if batch_frames >= max_frames {
+        return MeD2cFlushReason::BatchFrames;
+    }
+    if batch_bytes >= max_bytes {
+        return MeD2cFlushReason::BatchBytes;
+    }
+    if max_delay_fired {
+        return MeD2cFlushReason::MaxDelay;
+    }
+    MeD2cFlushReason::QueueDrain
+}
+
+fn observe_me_d2c_flush_event(
+    stats: &Stats,
+    reason: MeD2cFlushReason,
+    batch_frames: usize,
+    batch_bytes: usize,
+    flush_duration_us: Option<u64>,
+) {
+    stats.increment_me_d2c_flush_reason(reason);
+    if batch_frames > 0 || batch_bytes > 0 {
+        stats.increment_me_d2c_batches_total();
+        stats.add_me_d2c_batch_frames_total(batch_frames as u64);
+        stats.add_me_d2c_batch_bytes_total(batch_bytes as u64);
+        stats.observe_me_d2c_batch_frames(batch_frames as u64);
+        stats.observe_me_d2c_batch_bytes(batch_bytes as u64);
+    }
+    if let Some(duration_us) = flush_duration_us {
+        stats.observe_me_d2c_flush_duration_us(duration_us);
+    }
 }
 #[cfg(test)]
-fn quota_user_lock_test_guard() -> &'static Mutex<()> {
+fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
 static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
 TEST_LOCK.get_or_init(|| Mutex::new(()))
 }
 #[cfg(test)]
-fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
-    quota_user_lock_test_guard()
+pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
+    relay_idle_pressure_test_guard()
 .lock()
 .unwrap_or_else(|poisoned| poisoned.into_inner())
 }
-fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
-    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
-        (0..QUOTA_OVERFLOW_LOCK_STRIPES)
-            .map(|_| Arc::new(AsyncMutex::new(())))
-            .collect()
-    });
-
-    let hash = crc32fast::hash(user.as_bytes()) as usize;
-    Arc::clone(&stripes[hash % stripes.len()])
-}
-
-fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
-    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
-    if let Some(existing) = locks.get(user) {
-        return Arc::clone(existing.value());
-    }
-
-    if locks.len() >= QUOTA_USER_LOCKS_MAX {
-        locks.retain(|_, value| Arc::strong_count(value) > 1);
-    }
-
-    if locks.len() >= QUOTA_USER_LOCKS_MAX {
-        return quota_overflow_user_lock(user);
-    }
-
-    let created = Arc::new(AsyncMutex::new(()));
-    match locks.entry(user.to_string()) {
-        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
-        dashmap::mapref::entry::Entry::Vacant(entry) => {
-            entry.insert(Arc::clone(&created));
-            created
-        }
-    }
-}
-
 async fn enqueue_c2me_command(
 tx: &mpsc::Sender<C2MeCommand>,
 cmd: C2MeCommand,
+    send_timeout: Option<Duration>,
 ) -> std::result::Result<(), mpsc::error::SendError<C2MeCommand>> {
 match tx.try_send(cmd) {
 Ok(()) => Ok(()),
@@ -599,18 +655,34 @@
 if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
 tokio::task::yield_now().await;
 }
- match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
- Ok(Ok(permit)) => {
+            let reserve_result = match send_timeout {
+                Some(send_timeout) => match timeout(send_timeout, tx.reserve()).await {
+                    Ok(result) => result,
+                    Err(_) => return Err(mpsc::error::SendError(cmd)),
+                },
+                None => tx.reserve().await,
+            };
+            match reserve_result {
+                Ok(permit) => {
 permit.send(cmd);
 Ok(())
 }
- Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
 Err(_) => Err(mpsc::error::SendError(cmd)),
 }
 }
 }
 }
+#[cfg(test)]
+async fn run_relay_test_step_timeout<T, F>(context: &'static str, fut: F) -> T
+where
+    F: Future<Output = T>,
+{
+    timeout(RELAY_TEST_STEP_TIMEOUT, fut)
+        .await
+        .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
+}
+
 pub(crate) async fn handle_via_middle_proxy(
 mut crypto_reader: CryptoReader,
 crypto_writer: CryptoWriter,
@@ -631,6 +703,7 @@ where
 {
 let user = success.user.clone();
 let quota_limit = config.access.user_data_quota.get(&user).copied();
+    let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
 let peer = success.peer;
 let proto_tag = success.proto_tag;
 let pool_generation = me_pool.current_generation();
@@ -719,6 +792,10 @@ where
 .general
 .me_c2me_channel_capacity
 .max(C2ME_CHANNEL_CAPACITY_FALLBACK);
+    let c2me_send_timeout = match config.general.me_c2me_send_timeout_ms {
+        0 => None,
+        timeout_ms => Some(Duration::from_millis(timeout_ms)),
+    };
 let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
 let me_pool_c2me = me_pool.clone();
 let c2me_sender = tokio::spawn(async move {
@@ -757,6 +834,7 @@ where
 let stats_clone = stats.clone();
 let rng_clone = rng.clone();
 let user_clone = user.clone();
+    let quota_user_stats_me_writer = quota_user_stats.clone();
 let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
 let bytes_me2c_clone = bytes_me2c.clone();
 let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@@ -774,6 +852,7 @@ where
 let mut batch_frames = 0usize;
 let mut batch_bytes = 0usize;
 let mut flush_immediately;
+            let mut max_delay_fired = false;
 let first_is_downstream_activity = matches!(&first, MeResponse::Data { ..
} | MeResponse::Ack(_)); @@ -785,7 +864,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -801,7 +882,25 @@ where flush_immediately = immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -824,7 +923,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -840,7 +941,27 @@ where flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = + if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -851,6 +972,7 @@ where && batch_frames < d2c_flush_policy.max_frames && batch_bytes < d2c_flush_policy.max_bytes { + stats_clone.increment_me_d2c_batch_timeout_armed_total(); match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { Ok(Some(next)) => { let next_is_downstream_activity = @@ -863,7 +985,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -879,7 +1003,30 @@ where flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + flush_duration_us, + ); return Ok(()); } } @@ -902,7 +1049,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, + d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -918,7 +1067,30 @@ where flush_immediately |= immediate; } MeWriterResponseOutcome::Close => { + let flush_started_at = if stats_clone + .telemetry_policy() + .me_level + .allows_debug() + { + Some(Instant::now()) + } else { + None + }; let _ = writer.flush().await; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) + as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + MeD2cFlushReason::Close, + batch_frames, + batch_bytes, + 
flush_duration_us, + ); return Ok(()); } } @@ -928,11 +1100,50 @@ where debug!(conn_id, "ME channel closed"); return Err(ProxyError::Proxy("ME connection lost".into())); } - Err(_) => {} + Err(_) => { + max_delay_fired = true; + stats_clone.increment_me_d2c_batch_timeout_fired_total(); + } } } + let flush_reason = classify_me_d2c_flush_reason( + flush_immediately, + batch_frames, + d2c_flush_policy.max_frames, + batch_bytes, + d2c_flush_policy.max_bytes, + max_delay_fired, + ); + let flush_started_at = if stats_clone.telemetry_policy().me_level.allows_debug() { + Some(Instant::now()) + } else { + None + }; writer.flush().await.map_err(ProxyError::Io)?; + let flush_duration_us = flush_started_at.map(|started| { + started + .elapsed() + .as_micros() + .min(u128::from(u64::MAX)) as u64 + }); + observe_me_d2c_flush_event( + stats_clone.as_ref(), + flush_reason, + batch_frames, + batch_bytes, + flush_duration_us, + ); + let shrink_threshold = d2c_flush_policy.frame_buf_shrink_threshold_bytes; + let shrink_trigger = shrink_threshold + .saturating_mul(ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR); + if frame_buf.capacity() > shrink_trigger { + let cap_before = frame_buf.capacity(); + frame_buf.shrink_to(shrink_threshold); + let cap_after = frame_buf.capacity(); + let bytes_freed = cap_before.saturating_sub(cap_after) as u64; + stats_clone.observe_me_d2c_frame_buf_shrink(bytes_freed); + } } _ = &mut stop_rx => { debug!(conn_id, "ME writer stop signal"); @@ -961,7 +1172,7 @@ where user = %user, "Middle-relay pressure eviction for idle-candidate session" ); - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; main_result = Err(ProxyError::Proxy( "middle-relay session evicted under pressure (idle-candidate)".to_string(), )); @@ -980,7 +1191,7 @@ where "Cutover affected middle session, closing client connection" ); tokio::time::sleep(delay).await; - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); break; } @@ -1010,16 +1221,23 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - if let Some(limit) = quota_limit { - let quota_lock = quota_user_lock(&user); - let _quota_guard = quota_lock.lock().await; - stats.add_user_octets_from(&user, payload.len() as u64); - if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + if reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + ) + .await + .is_err() + { main_result = Err(ProxyError::DataQuotaExceeded { user: user.clone(), }); break; } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); } else { stats.add_user_octets_from(&user, payload.len() as u64); } @@ -1031,8 +1249,12 @@ where flags |= RPC_FLAG_NOT_ENCRYPTED; } // Keep client read loop lightweight: route heavy ME send path via a dedicated task. 
- if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags })
- .await
+                        if enqueue_c2me_command(
+                            &c2me_tx,
+                            C2MeCommand::Data { payload, flags },
+                            c2me_send_timeout,
+                        )
+                        .await
 .is_err()
 {
 main_result = Err(ProxyError::Proxy("ME sender channel closed".into()));
@@ -1042,7 +1264,9 @@
 Ok(None) => {
 debug!(conn_id, "Client EOF");
 client_closed = true;
- let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
+                    let _ =
+                        enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout)
+                            .await;
 break;
 }
 Err(e) => {
@@ -1112,6 +1336,8 @@ async fn read_client_payload_with_idle_policy(
 where
 R: AsyncRead + Unpin + Send + 'static,
 {
+    const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
+
 async fn read_exact_with_policy(
 client_reader: &mut CryptoReader,
 buf: &mut [u8],
@@ -1250,6 +1476,7 @@ where
 Ok(())
 }
+    let mut consecutive_zero_len_frames = 0u32;
 loop {
 let (len, quickack, raw_len_bytes) = match proto_tag {
 ProtoTag::Abridged => {
@@ -1330,6 +1557,26 @@ where
 };
 if len == 0 {
+            idle_state.tiny_frame_debt = idle_state
+                .tiny_frame_debt
+                .saturating_add(TINY_FRAME_DEBT_PER_TINY);
+            if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
+                stats.increment_relay_protocol_desync_close_total();
+                return Err(ProxyError::Proxy(format!(
+                    "Tiny frame overhead limit exceeded: debt={}, conn_id={}",
+                    idle_state.tiny_frame_debt, forensics.conn_id
+                )));
+            }
+
+            if !idle_policy.enabled {
+                consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1);
+                if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
+                    stats.increment_relay_protocol_desync_close_total();
+                    return Err(ProxyError::Proxy(
+                        "Excessive zero-length abridged frames".to_string(),
+                    ));
+                }
+            }
 continue;
 }
 if len < 4 && proto_tag != ProtoTag::Abridged {
@@ -1398,6 +1645,7 @@ where
 }
 *frame_counter += 1;
 idle_state.on_client_frame(Instant::now());
+        idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
 clear_relay_idle_candidate(forensics.conn_id);
 return Ok(Some((payload, quickack)));
 }
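// Reviewer sketch (not part of the diff): the tiny-frame debt accounting used
// in read_client_payload_with_idle_policy above, in a nutshell. Zero-length
// frames add TINY_FRAME_DEBT_PER_TINY (8) to a per-connection debt counter and
// each real payload frame pays 1 back, so a client must interleave roughly
// eight useful frames per tiny frame on average; reaching
// TINY_FRAME_DEBT_LIMIT (512, i.e. about 64 unpaid tiny frames) closes the
// connection as a protocol-desync/abuse signal. The wrapper type below is
// hypothetical; the constants mirror the diff.
struct TinyFrameDebt(u32);

impl TinyFrameDebt {
    const PER_TINY: u32 = 8;
    const LIMIT: u32 = 512;

    /// Returns false when the connection should be closed.
    fn on_tiny_frame(&mut self) -> bool {
        self.0 = self.0.saturating_add(Self::PER_TINY);
        self.0 < Self::LIMIT
    }

    /// A real payload frame pays one unit of debt back.
    fn on_payload_frame(&mut self) {
        self.0 = self.0.saturating_sub(1);
    }
}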
@@ -1481,7 +1729,9 @@ async fn process_me_writer_response(
 frame_buf: &mut Vec<u8>,
 stats: &Stats,
 user: &str,
+    quota_user_stats: Option<&UserStats>,
 quota_limit: Option<u64>,
+    quota_soft_overshoot_bytes: u64,
 bytes_me2c: &AtomicU64,
 conn_id: u64,
 ack_flush_immediate: bool,
@@ -1498,33 +1748,43 @@ where
 trace!(conn_id, bytes = data.len(), flags, "ME->C data");
 }
 let data_len = data.len() as u64;
- if let Some(limit) = quota_limit {
- let quota_lock = quota_user_lock(user);
- let _quota_guard = quota_lock.lock().await;
- if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
+            if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
+                let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
+                if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
+                    .await
+                    .is_err()
+                {
+                    stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
 return Err(ProxyError::DataQuotaExceeded {
 user: user.to_string(),
 });
 }
- write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
- .await?;
-
- bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
- stats.add_user_octets_to(user, data.len() as u64);
-
- if quota_exceeded_for_user(stats, user, Some(limit)) {
- return Err(ProxyError::DataQuotaExceeded {
- user: user.to_string(),
- });
- }
- } else {
- write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
- .await?;
-
- bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
- stats.add_user_octets_to(user, data.len() as u64);
 }
+            let write_mode =
+                match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
+                    .await
+                {
+                    Ok(mode) => mode,
+                    Err(err) => {
+                        if quota_limit.is_some() {
+                            stats.add_quota_write_fail_bytes_total(data_len);
+                            stats.increment_quota_write_fail_events_total();
+                        }
+                        return Err(err);
+                    }
+                };
+
+            bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
+            if let Some(user_stats) = quota_user_stats {
+                stats.add_user_octets_to_handle(user_stats, data_len);
+            } else {
+                stats.add_user_octets_to(user, data_len);
+            }
+            stats.increment_me_d2c_data_frames_total();
+            stats.add_me_d2c_payload_bytes_total(data_len);
+            stats.increment_me_d2c_write_mode(write_mode);
+
 Ok(MeWriterResponseOutcome::Continue {
 frames: 1,
 bytes: data.len(),
@@ -1538,6 +1798,7 @@ where
 trace!(conn_id, confirm, "ME->C quickack");
 }
 write_client_ack(client_writer, proto_tag, confirm).await?;
+            stats.increment_me_d2c_ack_frames_total();
 Ok(MeWriterResponseOutcome::Continue {
 frames: 1,
@@ -1588,13 +1849,13 @@ async fn write_client_payload(
 data: &[u8],
 rng: &SecureRandom,
 frame_buf: &mut Vec<u8>,
-) -> Result<()>
+) -> Result<MeD2cWriteMode>
 where
 W: AsyncWrite + Unpin + Send + 'static,
 {
 let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
- match proto_tag {
+    let write_mode = match proto_tag {
 ProtoTag::Abridged => {
 if !data.len().is_multiple_of(4) {
 return Err(ProxyError::Proxy(format!(
@@ -1609,28 +1870,58 @@ where
 if quickack {
 first |= 0x80;
 }
- frame_buf.clear();
- frame_buf.reserve(1 + data.len());
- frame_buf.push(first);
- frame_buf.extend_from_slice(data);
- client_writer
- .write_all(frame_buf)
- .await
- .map_err(ProxyError::Io)?;
+                let wire_len = 1usize.saturating_add(data.len());
+                if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES {
+                    frame_buf.clear();
+                    frame_buf.reserve(wire_len);
+                    frame_buf.push(first);
+                    frame_buf.extend_from_slice(data);
+                    client_writer
+                        .write_all(frame_buf.as_slice())
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    MeD2cWriteMode::Coalesced
+                } else {
+                    let header = [first];
+                    client_writer
+                        .write_all(&header)
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    client_writer
+                        .write_all(data)
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    MeD2cWriteMode::Split
+                }
 } else if len_words < (1 << 24) {
 let mut first = 0x7fu8;
 if quickack {
 first |= 0x80;
 }
 let lw = (len_words as u32).to_le_bytes();
- frame_buf.clear();
- frame_buf.reserve(4 + data.len());
- frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
- frame_buf.extend_from_slice(data);
- client_writer
- .write_all(frame_buf)
- .await
- .map_err(ProxyError::Io)?;
+                let wire_len = 4usize.saturating_add(data.len());
+                if wire_len <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES {
+                    frame_buf.clear();
+                    frame_buf.reserve(wire_len);
+                    frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
+                    frame_buf.extend_from_slice(data);
+                    client_writer
+                        .write_all(frame_buf.as_slice())
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    MeD2cWriteMode::Coalesced
+                } else {
+                    let header = [first, lw[0], lw[1], lw[2]];
+                    client_writer
+                        .write_all(&header)
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    client_writer
+                        .write_all(data)
+                        .await
+                        .map_err(ProxyError::Io)?;
+                    MeD2cWriteMode::Split
+                }
 } else {
 return Err(ProxyError::Proxy(format!(
 "Abridged frame too large: {}",
@@ -1650,25 +1941,52 @@ where
 } else {
 0
 };
+            let (len_val, total) =
 compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?;
- frame_buf.clear();
- frame_buf.reserve(total);
-
frame_buf.extend_from_slice(&len_val.to_le_bytes()); - frame_buf.extend_from_slice(data); - if padding_len > 0 { - let start = frame_buf.len(); - frame_buf.resize(start + padding_len, 0); - rng.fill(&mut frame_buf[start..]); + if total <= ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES { + frame_buf.clear(); + frame_buf.reserve(total); + frame_buf.extend_from_slice(&len_val.to_le_bytes()); + frame_buf.extend_from_slice(data); + if padding_len > 0 { + let start = frame_buf.len(); + frame_buf.resize(start + padding_len, 0); + rng.fill(&mut frame_buf[start..]); + } + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + MeD2cWriteMode::Coalesced + } else { + let header = len_val.to_le_bytes(); + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; + if padding_len > 0 { + frame_buf.clear(); + if frame_buf.capacity() < padding_len { + frame_buf.reserve(padding_len); + } + frame_buf.resize(padding_len, 0); + rng.fill(frame_buf.as_mut_slice()); + client_writer + .write_all(frame_buf.as_slice()) + .await + .map_err(ProxyError::Io)?; + } + MeD2cWriteMode::Split } - client_writer - .write_all(frame_buf) - .await - .map_err(ProxyError::Io)?; } - } + }; - Ok(()) + Ok(write_mode) } async fn write_client_ack( @@ -1690,10 +2008,6 @@ where .map_err(ProxyError::Io) } -#[cfg(test)] -#[path = "tests/middle_relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; @@ -1706,18 +2020,30 @@ mod desync_all_full_dedup_security_tests; #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] -mod coverage_high_risk_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] -mod quota_overflow_lock_security_tests; - #[cfg(test)] #[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] mod length_cast_hardening_security_tests; #[cfg(test)] -#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] -mod blackhat_campaign_integration_tests; +#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] +mod middle_relay_idle_registry_poison_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"] +mod middle_relay_zero_length_frame_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"] +mod middle_relay_tiny_frame_debt_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"] +mod middle_relay_tiny_frame_debt_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] +mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"] +mod middle_relay_atomic_quota_invariant_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..5880558 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -4,58 +4,58 @@ #![cfg_attr(test, allow(warnings))] #![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] #![cfg_attr( - not(test), - deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::panic, - clippy::todo, - clippy::unimplemented, - clippy::correctness, - clippy::option_if_let_else, - 
clippy::or_fun_call, - clippy::branches_sharing_code, - clippy::single_option_map, - clippy::useless_let_if_seq, - clippy::redundant_locals, - clippy::cloned_ref_to_slice_refs, - unsafe_code, - clippy::await_holding_lock, - clippy::await_holding_refcell_ref, - clippy::debug_assert_with_mut_call, - clippy::macro_use_imports, - clippy::cast_ptr_alignment, - clippy::cast_lossless, - clippy::ptr_as_ptr, - clippy::large_stack_arrays, - clippy::same_functions_in_if_condition, - trivial_casts, - trivial_numeric_casts, - unused_extern_crates, - unused_import_braces, - rust_2018_idioms - ) + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) )] #![cfg_attr( - not(test), - allow( - clippy::use_self, - clippy::redundant_closure, - clippy::too_many_arguments, - clippy::doc_markdown, - clippy::missing_const_for_fn, - clippy::unnecessary_operation, - clippy::redundant_pub_crate, - clippy::derive_partial_eq_without_eq, - clippy::type_complexity, - clippy::new_ret_no_self, - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - clippy::significant_drop_tightening, - clippy::significant_drop_in_scrutinee, - clippy::float_cmp, - clippy::nursery - ) + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) )] pub mod adaptive_buffers; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2431ff4..6000e18 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,13 +52,12 @@ //! 
- `SharedCounters` (atomics) let the watchdog read stats without locking
 use crate::error::{ProxyError, Result};
-use crate::stats::Stats;
+use crate::stats::{Stats, UserStats};
 use crate::stream::BufferPool;
-use dashmap::DashMap;
 use std::io;
 use std::pin::Pin;
+use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
-use std::sync::{Arc, Mutex, OnceLock};
 use std::task::{Context, Poll};
 use std::time::Duration;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
@@ -209,12 +208,10 @@ struct StatsIo {
 counters: Arc<SharedCounters>,
 stats: Arc<Stats>,
 user: String,
+    user_stats: Arc<UserStats>,
 quota_limit: Option<u64>,
 quota_exceeded: Arc<AtomicBool>,
-    quota_read_wake_scheduled: bool,
-    quota_write_wake_scheduled: bool,
-    quota_read_retry_active: Arc<AtomicBool>,
-    quota_write_retry_active: Arc<AtomicBool>,
+    quota_bytes_since_check: u64,
 epoch: Instant,
 }
@@ -230,30 +227,21 @@ impl StatsIo {
 ) -> Self {
 // Mark initial activity so the watchdog doesn't fire before data flows
 counters.touch(Instant::now(), epoch);
+        let user_stats = stats.get_or_create_user_stats_handle(&user);
 Self {
 inner,
 counters,
 stats,
 user,
+            user_stats,
 quota_limit,
 quota_exceeded,
-            quota_read_wake_scheduled: false,
-            quota_write_wake_scheduled: false,
-            quota_read_retry_active: Arc::new(AtomicBool::new(false)),
-            quota_write_retry_active: Arc::new(AtomicBool::new(false)),
+            quota_bytes_since_check: 0,
 epoch,
 }
 }
 }
-impl Drop for StatsIo {
-    fn drop(&mut self) {
-        self.quota_read_retry_active.store(false, Ordering::Relaxed);
-        self.quota_write_retry_active
-            .store(false, Ordering::Relaxed);
-    }
-}
-
 #[derive(Debug)]
 struct QuotaIoSentinel;
@@ -277,84 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool {
 .is_some()
 }
-#[cfg(test)]
-const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
-#[cfg(not(test))]
-const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
+const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
+const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
+const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
+const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
-fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
-    tokio::task::spawn(async move {
-        loop {
-            if !retry_active.load(Ordering::Relaxed) {
-                break;
-            }
-            tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
-            if !retry_active.load(Ordering::Relaxed) {
-                break;
-            }
-            waker.wake_by_ref();
-        }
-    });
+#[inline]
+fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
+    remaining_before.saturating_div(2).clamp(
+        QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
+        QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
+    )
 }
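// Reviewer sketch (not part of the diff): a hypothetical unit test pinning the
// worked values of quota_adaptive_interval_bytes above. The interval is half
// the remaining quota, clamped to [4 KiB, 64 KiB], so a session far from its
// cap pays one quota load per ~64 KiB relayed, while a near-limit session is
// re-checked every few KiB.
#[cfg(test)]
#[test]
fn adaptive_interval_worked_examples() {
    assert_eq!(quota_adaptive_interval_bytes(1024 * 1024), 64 * 1024); // clamped to the max
    assert_eq!(quota_adaptive_interval_bytes(40 * 1024), 20 * 1024); // plain remaining / 2
    assert_eq!(quota_adaptive_interval_bytes(6 * 1024), 4 * 1024); // clamped to the min
}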
-static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
-static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
-
-#[cfg(test)]
-const QUOTA_USER_LOCKS_MAX: usize = 64;
-#[cfg(not(test))]
-const QUOTA_USER_LOCKS_MAX: usize = 4_096;
-#[cfg(test)]
-const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
-#[cfg(not(test))]
-const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
-
-#[cfg(test)]
-fn quota_user_lock_test_guard() -> &'static Mutex<()> {
-    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
-    TEST_LOCK.get_or_init(|| Mutex::new(()))
-}
-
-#[cfg(test)]
-fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
-    quota_user_lock_test_guard()
-        .lock()
-        .unwrap_or_else(|poisoned| poisoned.into_inner())
-}
-
-fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
-    let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
-        (0..QUOTA_OVERFLOW_LOCK_STRIPES)
-            .map(|_| Arc::new(Mutex::new(())))
-            .collect()
-    });
-
-    let hash = crc32fast::hash(user.as_bytes()) as usize;
-    Arc::clone(&stripes[hash % stripes.len()])
-}
-
-fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
-    let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
-    if let Some(existing) = locks.get(user) {
-        return Arc::clone(existing.value());
-    }
-
-    if locks.len() >= QUOTA_USER_LOCKS_MAX {
-        locks.retain(|_, value| Arc::strong_count(value) > 1);
-    }
-
-    if locks.len() >= QUOTA_USER_LOCKS_MAX {
-        return quota_overflow_user_lock(user);
-    }
-
-    let created = Arc::new(Mutex::new(()));
-    match locks.entry(user.to_string()) {
-        dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
-        dashmap::mapref::entry::Entry::Vacant(entry) => {
-            entry.insert(Arc::clone(&created));
-            created
-        }
-    }
+#[inline]
+fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
+    remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
 }
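// Reviewer sketch (not part of the diff): how the two predicates above are
// meant to compose in the poll_read/poll_write paths that follow. Charges are
// applied optimistically; the authoritative counter is only re-read
// immediately when the session is near its limit or the charge is large, and
// otherwise after an adaptive number of bytes. Illustrative shape only;
// `quota_check_due` is a hypothetical helper, not project code.
fn quota_check_due(remaining_before: u64, charge_bytes: u64, bytes_since_check: &mut u64) -> bool {
    if should_immediate_quota_check(remaining_before, charge_bytes) {
        *bytes_since_check = 0;
        return true; // re-read the shared counter right away
    }
    *bytes_since_check = bytes_since_check.saturating_add(charge_bytes);
    if *bytes_since_check >= quota_adaptive_interval_bytes(remaining_before) {
        *bytes_since_check = 0;
        return true; // adaptive cadence expired
    }
    false
}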
@@ -364,80 +290,60 @@ impl AsyncRead for StatsIo {
 fn poll_read(
 self: Pin<&mut Self>,
 cx: &mut Context<'_>,
 buf: &mut ReadBuf<'_>,
 ) -> Poll<io::Result<()>> {
 let this = self.get_mut();
- if this.quota_exceeded.load(Ordering::Relaxed) {
+        if this.quota_exceeded.load(Ordering::Acquire) {
 return Poll::Ready(Err(quota_io_error()));
 }
- let quota_lock = this
- .quota_limit
- .is_some()
- .then(|| quota_user_lock(&this.user));
- let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
- match lock.try_lock() {
- Ok(guard) => {
- this.quota_read_wake_scheduled = false;
- this.quota_read_retry_active.store(false, Ordering::Relaxed);
- Some(guard)
- }
- Err(_) => {
- if !this.quota_read_wake_scheduled {
- this.quota_read_wake_scheduled = true;
- this.quota_read_retry_active.store(true, Ordering::Relaxed);
- spawn_quota_retry_waker(
- Arc::clone(&this.quota_read_retry_active),
- cx.waker().clone(),
- );
- }
- return Poll::Pending;
- }
+        let mut remaining_before = None;
+        if let Some(limit) = this.quota_limit {
+            let used_before = this.user_stats.quota_used();
+            let remaining = limit.saturating_sub(used_before);
+            if remaining == 0 {
+                this.quota_exceeded.store(true, Ordering::Release);
+                return Poll::Ready(Err(quota_io_error()));
 }
- } else {
- None
- };
-
- if let Some(limit) = this.quota_limit
- && this.stats.get_user_total_octets(&this.user) >= limit
- {
- this.quota_exceeded.store(true, Ordering::Relaxed);
- return Poll::Ready(Err(quota_io_error()));
+            remaining_before = Some(remaining);
 }
+
 let before = buf.filled().len();
 match Pin::new(&mut this.inner).poll_read(cx, buf) {
 Poll::Ready(Ok(())) => {
 let n = buf.filled().len() - before;
 if n > 0 {
- let mut reached_quota_boundary = false;
- if let Some(limit) = this.quota_limit {
- let used = this.stats.get_user_total_octets(&this.user);
- if used >= limit {
- this.quota_exceeded.store(true, Ordering::Relaxed);
- return Poll::Ready(Err(quota_io_error()));
- }
-
- let remaining = limit - used;
- if (n as u64) > remaining {
- // Fail closed: when a single read chunk would cross quota,
- // stop relay immediately without accounting beyond the cap.
- this.quota_exceeded.store(true, Ordering::Relaxed);
- return Poll::Ready(Err(quota_io_error()));
- }
-
- reached_quota_boundary = (n as u64) == remaining;
- }
+                    let n_to_charge = n as u64;
 // C→S: client sent data
 this.counters
 .c2s_bytes
- .fetch_add(n as u64, Ordering::Relaxed);
+                        .fetch_add(n_to_charge, Ordering::Relaxed);
 this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
 this.counters.touch(Instant::now(), this.epoch);
- this.stats.add_user_octets_from(&this.user, n as u64);
- this.stats.increment_user_msgs_from(&this.user);
+                    this.stats
+                        .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
+                    this.stats
+                        .increment_user_msgs_from_handle(this.user_stats.as_ref());
- if reached_quota_boundary {
- this.quota_exceeded.store(true, Ordering::Relaxed);
+                    if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
+                        this.stats
+                            .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
+                        if should_immediate_quota_check(remaining, n_to_charge) {
+                            this.quota_bytes_since_check = 0;
+                            if this.user_stats.quota_used() >= limit {
+                                this.quota_exceeded.store(true, Ordering::Release);
+                            }
+                        } else {
+                            this.quota_bytes_since_check =
+                                this.quota_bytes_since_check.saturating_add(n_to_charge);
+                            let interval = quota_adaptive_interval_bytes(remaining);
+                            if this.quota_bytes_since_check >= interval {
+                                this.quota_bytes_since_check = 0;
+                                if this.user_stats.quota_used() >= limit {
+                                    this.quota_exceeded.store(true, Ordering::Release);
+                                }
+                            }
+                        }
 }
 trace!(user = %this.user, bytes = n, "C->S");
@@ -456,75 +362,57 @@ impl AsyncWrite for StatsIo {
 fn poll_write(
 self: Pin<&mut Self>,
 cx: &mut Context<'_>,
 buf: &[u8],
 ) -> Poll<io::Result<usize>> {
 let this = self.get_mut();
- if this.quota_exceeded.load(Ordering::Relaxed) {
+        if this.quota_exceeded.load(Ordering::Acquire) {
 return Poll::Ready(Err(quota_io_error()));
 }
- let quota_lock = this
- .quota_limit
- .is_some()
- .then(|| quota_user_lock(&this.user));
- let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
- match lock.try_lock() {
- Ok(guard) => {
- this.quota_write_wake_scheduled = false;
- this.quota_write_retry_active
- .store(false, Ordering::Relaxed);
- Some(guard)
- }
- Err(_) => {
- if !this.quota_write_wake_scheduled {
- this.quota_write_wake_scheduled = true;
- this.quota_write_retry_active.store(true, Ordering::Relaxed);
- spawn_quota_retry_waker(
- Arc::clone(&this.quota_write_retry_active),
- cx.waker().clone(),
- );
- }
- return Poll::Pending;
- }
- }
- } else {
- None
- };
-
- let write_buf = if let Some(limit) = this.quota_limit {
- let used = this.stats.get_user_total_octets(&this.user);
- if used >= limit {
- this.quota_exceeded.store(true, Ordering::Relaxed);
+        let mut remaining_before = None;
+        if let Some(limit) = this.quota_limit {
+            let used_before = this.user_stats.quota_used();
+            let remaining = limit.saturating_sub(used_before);
+            if remaining == 0 {
+                this.quota_exceeded.store(true, Ordering::Release);
 return Poll::Ready(Err(quota_io_error()));
 }
+            remaining_before = Some(remaining);
+        }
- let remaining = (limit - used) as usize;
- if buf.len() > remaining {
- // Fail closed: do not emit partial S->C payload when remaining
- // quota cannot accommodate the pending write request.
- this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - buf - } else { - buf - }; - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + match Pin::new(&mut this.inner).poll_write(cx, buf) { Poll::Ready(Ok(n)) => { if n > 0 { + let n_to_charge = n as u64; + // S→C: data written to client this.counters .s2c_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_to(&this.user, n as u64); - this.stats.increment_user_msgs_to(&this.user); + this.stats + .add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_to_handle(this.user_stats.as_ref()); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "S->C"); @@ -618,7 +506,7 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); - if wd_quota_exceeded.load(Ordering::Relaxed) { + if wd_quota_exceeded.load(Ordering::Acquire) { warn!(user = %wd_user, "User data quota reached, closing relay"); return; } @@ -756,18 +644,10 @@ where } } -#[cfg(test)] -#[path = "tests/relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] mod adversarial_tests; -#[cfg(test)] -#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; - #[cfg(test)] #[path = "tests/relay_quota_boundary_blackhat_tests.rs"] mod relay_quota_boundary_blackhat_tests; @@ -780,14 +660,14 @@ mod relay_quota_model_adversarial_tests; #[path = "tests/relay_quota_overflow_regression_tests.rs"] mod relay_quota_overflow_regression_tests; +#[cfg(test)] +#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"] +mod relay_quota_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; #[cfg(test)] -#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] -mod relay_quota_waker_storm_adversarial_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; +#[path = "tests/relay_atomic_quota_invariant_tests.rs"] +mod relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs new file mode 100644 index 0000000..f462ed8 --- /dev/null +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -0,0 +1,467 @@ +use super::*; +use 
crate::config::{ProxyConfig, UpstreamConfig, UpstreamType}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex}; +use tokio::net::TcpListener; + +#[test] +fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 1000; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + let elapsed = start.elapsed(); + + assert!(elapsed >= Duration::from_millis(1000)); + assert!(elapsed < Duration::from_millis(1500)); + }); +} + +#[test] +fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert_eq!(timeout.as_secs(), u64::MAX); +} + +#[test] +fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); +} + +#[test] +fn edge_synthetic_local_addr_boundaries() { + assert_eq!(synthetic_local_addr(0).port(), 0); + assert_eq!(synthetic_local_addr(80).port(), 80); + assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX); +} + +#[test] +fn edge_beobachten_record_handshake_failure_class_stream_error_eof() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof); + let peer_ip: IpAddr = "198.51.100.100".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn adversarial_tls_handshake_timeout_during_masking_delay() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.server_hello_delay_min_ms = 3000; + cfg.censorship.server_hello_delay_max_ms = 3000; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.1:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]) + .await + 
.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); + assert_eq!(stats.get_handshake_timeouts(), 1); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_slowloris_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 200; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.2:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(b"PROXY TCP4 192.").await.unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[test] +fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() { + let trusted = vec!["192.0.2.0/24".parse().unwrap()]; + let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + assert!(!is_trusted_proxy_source(peer_ip, &trusted)); +} + +#[tokio::test] +async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 500; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.3:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_client_stream_exactly_4_bytes_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.4:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + 
+    client_side
+        .write_all(&[0x16, 0x03, 0x01, 0x00])
+        .await
+        .unwrap();
+    client_side.shutdown().await.unwrap();
+
+    let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
+
+    let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
+    assert!(snapshot.contains("[expected_64_got_0]"));
+}
+
+#[tokio::test]
+async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
+    let config = Arc::new(ProxyConfig::default());
+    let stats = Arc::new(Stats::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let handle = tokio::spawn(handle_client_stream(
+        server_side,
+        "198.51.100.5:55000".parse().unwrap(),
+        config,
+        stats.clone(),
+        Arc::new(UpstreamManager::new(
+            vec![],
+            1,
+            1,
+            1,
+            1,
+            false,
+            stats.clone(),
+        )),
+        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
+        Arc::new(BufferPool::new()),
+        Arc::new(SecureRandom::new()),
+        None,
+        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        None,
+        Arc::new(UserIpTracker::new()),
+        Arc::new(BeobachtenStore::new()),
+        false,
+    ));
+
+    client_side
+        .write_all(&[0x16, 0x03, 0x01, 0x00, 100])
+        .await
+        .unwrap();
+    client_side.write_all(&vec![0x41; 99]).await.unwrap();
+    client_side.shutdown().await.unwrap();
+
+    let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
+    assert_eq!(stats.get_connects_bad(), 1);
+}
+
+#[tokio::test]
+async fn integration_non_tls_modes_disabled_immediately_masks() {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.modes.classic = false;
+    cfg.general.modes.secure = false;
+    cfg.censorship.mask = true;
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let handle = tokio::spawn(handle_client_stream(
+        server_side,
+        "198.51.100.6:55000".parse().unwrap(),
+        config,
+        stats.clone(),
+        Arc::new(UpstreamManager::new(
+            vec![],
+            1,
+            1,
+            1,
+            1,
+            false,
+            stats.clone(),
+        )),
+        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
+        Arc::new(BufferPool::new()),
+        Arc::new(SecureRandom::new()),
+        None,
+        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        None,
+        Arc::new(UserIpTracker::new()),
+        Arc::new(BeobachtenStore::new()),
+        false,
+    ));
+
+    client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap();
+    let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
+    assert_eq!(stats.get_connects_bad(), 1);
+}
+
+struct YieldingReader {
+    data: Vec<u8>,
+    pos: usize,
+    yields_left: usize,
+}
+
+impl AsyncRead for YieldingReader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let this = self.get_mut();
+        if this.yields_left > 0 {
+            this.yields_left -= 1;
+            cx.waker().wake_by_ref();
+            return Poll::Pending;
+        }
+        if this.pos >= this.data.len() {
+            return Poll::Ready(Ok(()));
+        }
+        buf.put_slice(&this.data[this.pos..this.pos + 1]);
+        this.pos += 1;
+        this.yields_left = 2;
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test]
+async fn fuzz_read_with_progress_heavy_yielding() {
+    let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec();
+    let mut reader = YieldingReader {
+        data: expected_data.clone(),
+        pos: 0,
+        yields_left: 2,
+    };
+
+    let mut buf = vec![0u8; expected_data.len()];
+    let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap();
+
+    assert_eq!(read_bytes, expected_data.len());
+    assert_eq!(buf, expected_data);
+}
+
+#[test]
+fn edge_wrap_tls_application_record_exactly_u16_max() {
+    let payload = vec![0u8; 65535];
+    let wrapped =
wrap_tls_application_record(&payload); + assert_eq!(wrapped.len(), 65540); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); +} + +#[test] +fn fuzz_wrap_tls_application_record_lengths() { + let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072]; + for len in lengths { + let payload = vec![0u8; len]; + let wrapped = wrap_tls_application_record(&payload); + let expected_chunks = len.div_ceil(65535).max(1); + assert_eq!(wrapped.len(), len + 5 * expected_chunks); + } +} + +#[tokio::test] +async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { + let user = "stress-same-ip-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 5); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 10).await; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000); + + let mut tasks = tokio::task::JoinSet::new(); + let mut reservations = Vec::new(); + + for _ in 0..10 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await + }); + } + + let mut successes = 0; + let mut failures = 0; + + while let Some(res) = tasks.join_next().await { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, 5); + assert_eq!(failures, 5); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for reservation in reservations { + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs new file mode 100644 index 0000000..e57f817 --- /dev/null +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -0,0 +1,222 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, duplex}; + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[test] +fn invariant_wrap_tls_application_record_exact_multiples() { + let chunk_size = u16::MAX as usize; + let payload = vec![0xAA; chunk_size * 2]; + + let wrapped = wrap_tls_application_record(&payload); + + assert_eq!(wrapped.len(), 2 * (5 + chunk_size)); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); + + let second_header_idx = 5 + chunk_size; + assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION); + assert_eq!( + &wrapped[second_header_idx + 3..second_header_idx + 5], + &65535u16.to_be_bytes() + ); +} + +#[tokio::test] +async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = 
tokio::spawn(handle_client_stream( + server_side, + "198.51.100.20:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16; + let mut header = vec![0x16, 0x03, 0x01]; + header.extend_from_slice(&claimed_len.to_be_bytes()); + + client_side.write_all(&header).await.unwrap(); + client_side + .write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn invariant_acquire_reservation_ip_limit_rollback() { + let user = "rollback-test-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer_a = "198.51.100.21:55000".parse().unwrap(); + let _res_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_a, + ip_tracker.clone(), + ) + .await + .unwrap(); + + assert_eq!(stats.get_user_curr_connects(user), 1); + + let peer_b = "203.0.113.22:55000".parse().unwrap(); + let res_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_b, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + res_b, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); +} + +#[tokio::test] +async fn invariant_quota_exact_boundary_inclusive() { + let user = "quota-strict-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1000); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.23:55000".parse().unwrap(); + + preload_user_quota(stats.as_ref(), user, 999); + let res1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(res1.is_ok()); + res1.unwrap().release().await; + + preload_user_quota(stats.as_ref(), user, 1); + let res2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. 
+    })));
+}
+
+#[tokio::test]
+async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = true;
+    cfg.general.beobachten_minutes = 1;
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let beobachten = Arc::new(BeobachtenStore::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let handler = tokio::spawn(handle_client_stream(
+        server_side,
+        "198.51.100.25:55000".parse().unwrap(),
+        config,
+        stats.clone(),
+        Arc::new(UpstreamManager::new(
+            vec![],
+            1,
+            1,
+            1,
+            1,
+            false,
+            stats.clone(),
+        )),
+        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
+        Arc::new(BufferPool::new()),
+        Arc::new(SecureRandom::new()),
+        None,
+        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        None,
+        Arc::new(UserIpTracker::new()),
+        beobachten.clone(),
+        false,
+    ));
+
+    client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap();
+    client_side.shutdown().await.unwrap();
+
+    let result = tokio::time::timeout(Duration::from_secs(2), handler)
+        .await
+        .unwrap()
+        .unwrap();
+
+    assert!(result.is_err());
+    assert_eq!(stats.get_connects_bad(), 0);
+    let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
+    assert!(snapshot.contains("[expected_64_got_0]"));
+}
+
+#[tokio::test]
+async fn invariant_route_mode_snapshot_picks_up_latest_mode() {
+    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
+    assert!(matches!(
+        route_runtime.snapshot().mode,
+        RelayRouteMode::Direct
+    ));
+
+    route_runtime.set_mode(RelayRouteMode::Middle);
+    assert!(matches!(
+        route_runtime.snapshot().mode,
+        RelayRouteMode::Middle
+    ));
+}
diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs
new file mode 100644
index 0000000..d7ac4ef
--- /dev/null
+++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs
@@ -0,0 +1,100 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use std::sync::Arc;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::net::TcpListener;
+use tokio::time::Duration;
+
+fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
+    Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats,
+    ))
+}
+
+#[tokio::test]
+async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+
+    let accept_task = tokio::spawn(async move {
+        let (mut stream, _) = listener.accept().await.unwrap();
+        let mut got = Vec::new();
+        stream.read_to_end(&mut got).await.unwrap();
+        got
+    });
+
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = true;
+    cfg.general.beobachten_minutes = 1;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = backend_addr.port();
+    cfg.general.modes.classic = false;
+    cfg.general.modes.secure = false;
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let beobachten = Arc::new(BeobachtenStore::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
"198.51.100.251:57501".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"CONNE").await.unwrap(); + client_side + .write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"), + "mask backend must receive the full fragmented CONNECT probe" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.251-1")); +} diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs new file mode 100644 index 0000000..3036f95 --- /dev/null +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + let first = split_at.min(preface.len()); + 
+    client_side.write_all(&preface[..first]).await.unwrap();
+    if first < preface.len() {
+        sleep(Duration::from_millis(delay_ms)).await;
+        client_side.write_all(&preface[first..]).await.unwrap();
+    }
+    client_side.shutdown().await.unwrap();
+
+    let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
+        .await
+        .unwrap()
+        .unwrap();
+    assert!(
+        forwarded.starts_with(&preface),
+        "mask backend must receive an intact HTTP/2 preface prefix"
+    );
+
+    let result = tokio::time::timeout(Duration::from_secs(3), handler)
+        .await
+        .unwrap()
+        .unwrap();
+    assert!(result.is_ok());
+
+    let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
+    assert!(snapshot.contains("[HTTP]"));
+    assert!(snapshot.contains(&format!("{}-1", peer.ip())));
+}
+
+#[tokio::test]
+async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
+    let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)];
+
+    for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
+        let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
+            .parse()
+            .unwrap();
+        run_http2_fragment_case(split_at, delay_ms, peer).await;
+    }
+}
+
+#[tokio::test]
+async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
+    for split_at in 2usize..=12 {
+        let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
+        let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
+            .parse()
+            .unwrap();
+        run_http2_fragment_case(split_at, delay_ms, peer).await;
+    }
+}
diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs
new file mode 100644
index 0000000..e64dc03
--- /dev/null
+++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs
@@ -0,0 +1,150 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use std::sync::Arc;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::net::TcpListener;
+use tokio::time::{Duration, sleep};
+
+fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
+    Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats,
+    ))
+}
+
+async fn run_pipeline_prefetch_case(
+    prefetch_timeout_ms: u64,
+    delayed_tail_ms: u64,
+    peer: SocketAddr,
+) -> (Vec<u8>, String) {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let backend_addr = listener.local_addr().unwrap();
+
+    let accept_task = tokio::spawn(async move {
+        let (mut stream, _) = listener.accept().await.unwrap();
+        let mut got = Vec::new();
+        stream.read_to_end(&mut got).await.unwrap();
+        got
+    });
+
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = true;
+    cfg.general.beobachten_minutes = 1;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = backend_addr.port();
+    cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
+    cfg.general.modes.classic = false;
+    cfg.general.modes.secure = false;
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let beobachten = Arc::new(BeobachtenStore::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let handler = tokio::spawn(handle_client_stream(
+        server_side,
+        peer,
+        
config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"C").await.unwrap(); + sleep(Duration::from_millis(delayed_tail_ms)).await; + + client_side + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + (forwarded, snapshot) +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() { + let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must still receive full payload bytes in-order" + ); + assert!( + snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"), + "unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}" + ); +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() { + let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must receive full CONNECT payload" + ); + assert!( + snapshot.contains("[HTTP]"), + "20ms budget should recover delayed fragmented prefix and classify as HTTP" + ); +} + +#[tokio::test] +async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() { + let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap(); + let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap(); + let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap(); + + let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await; + let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await; + let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await; + + assert!( + snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"), + "unexpected 5ms snapshot: {snap5}" + ); + assert!( + snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"), + "unexpected 20ms snapshot: {snap20}" + ); + assert!(snap50.contains("[HTTP]")); +} diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs new file mode 100644 index 0000000..64e7a85 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -0,0 +1,88 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep}; + +#[test] +fn prefetch_timeout_budget_reads_from_config() { + let mut cfg = ProxyConfig::default(); + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(5), + "default prefetch timeout budget must remain 5ms" + ); + + cfg.censorship.mask_classifier_prefetch_timeout_ms = 20; + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(20), + "runtime prefetch timeout budget 
must follow configured value"
+    );
+}
+
+#[tokio::test]
+async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
+    let (mut reader, mut writer) = duplex(1024);
+
+    let writer_task = tokio::spawn(async move {
+        sleep(Duration::from_millis(15)).await;
+        writer
+            .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
+            .await
+            .expect("tail bytes must be writable");
+        writer
+            .shutdown()
+            .await
+            .expect("writer shutdown must succeed");
+    });
+
+    let mut initial_data = b"C".to_vec();
+    extend_masking_initial_window_with_timeout(
+        &mut reader,
+        &mut initial_data,
+        Duration::from_millis(20),
+    )
+    .await;
+
+    writer_task
+        .await
+        .expect("writer task must not panic in runtime timeout test");
+
+    assert!(
+        initial_data.starts_with(b"CONNECT"),
+        "20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
+    );
+}
+
+#[tokio::test]
+async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
+    let (mut reader, mut writer) = duplex(1024);
+
+    let writer_task = tokio::spawn(async move {
+        sleep(Duration::from_millis(15)).await;
+        writer
+            .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
+            .await
+            .expect("tail bytes must be writable");
+        writer
+            .shutdown()
+            .await
+            .expect("writer shutdown must succeed");
+    });
+
+    let mut initial_data = b"C".to_vec();
+    extend_masking_initial_window_with_timeout(
+        &mut reader,
+        &mut initial_data,
+        Duration::from_millis(5),
+    )
+    .await;
+
+    writer_task
+        .await
+        .expect("writer task must not panic in runtime timeout test");
+
+    assert!(
+        !initial_data.starts_with(b"CONNECT"),
+        "5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
+    );
+}
diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs
new file mode 100644
index 0000000..b49db3c
--- /dev/null
+++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs
@@ -0,0 +1,264 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use crate::crypto::sha256_hmac;
+use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
+use crate::protocol::tls;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::net::TcpListener;
+
+struct PipelineHarness {
+    config: Arc<ProxyConfig>,
+    stats: Arc<Stats>,
+    upstream_manager: Arc<UpstreamManager>,
+    replay_checker: Arc<ReplayChecker>,
+    buffer_pool: Arc<BufferPool>,
+    rng: Arc<SecureRandom>,
+    route_runtime: Arc<RouteRuntimeController>,
+    ip_tracker: Arc<UserIpTracker>,
+    beobachten: Arc<BeobachtenStore>,
+}
+
+fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = false;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = mask_port;
+    cfg.censorship.mask_proxy_protocol = 0;
+    cfg.access.ignore_time_skew = true;
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let upstream_manager = Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats.clone(),
+    ));
+
+    PipelineHarness {
+        config,
+        stats,
+        upstream_manager,
+        replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
+        buffer_pool: Arc::new(BufferPool::new()),
+        rng: Arc::new(SecureRandom::new()),
+        route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        ip_tracker: Arc::new(UserIpTracker::new()),
+        beobachten: Arc::new(BeobachtenStore::new()),
+    }
+}
+
+fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
+    let total_len = 5 + tls_len;
+    let mut handshake = vec![fill; total_len];
+
+    handshake[0] = 0x16;
+    handshake[1] = 0x03;
+    handshake[2] = 0x01;
+    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
+
+    let session_id_len: usize = 32;
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+
+    handshake
+}
+
+fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
+    let mut record = Vec::with_capacity(5 + payload.len());
+    record.push(0x17);
+    record.extend_from_slice(&TLS_VERSION);
+    record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
+    record.extend_from_slice(payload);
+    record
+}
+
+async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
+where
+    T: tokio::io::AsyncRead + Unpin,
+{
+    let len = u16::from_be_bytes([header[3], header[4]]) as usize;
+    let mut body = vec![0u8; len];
+    stream.read_exact(&mut body).await.unwrap();
+}
+
+#[test]
+fn empty_initial_data_prefetch_gate_is_fail_closed() {
+    assert!(
+        !should_prefetch_mask_classifier_window(&[]),
+        "empty initial_data must not trigger classifier prefetch"
+    );
+}
+
+#[tokio::test]
+async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
+    let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
+    let (mut reader, mut writer) = duplex(1024);
+
+    writer.write_all(&payload).await.unwrap();
+    writer.shutdown().await.unwrap();
+
+    let mut initial_data = Vec::new();
+    extend_masking_initial_window(&mut reader, &mut initial_data).await;
+
+    assert!(
+        initial_data.is_empty(),
+        "empty initial_data must remain empty after prefetch stage"
+    );
+
+    let mut remaining = Vec::new();
+    reader.read_to_end(&mut remaining).await.unwrap();
+    assert_eq!(
+        remaining, payload,
+        "prefetch stage must not consume fallback payload when initial_data is empty"
+    );
+}
+
+#[tokio::test]
+async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
+    let (mut reader, mut writer) = duplex(1024);
+    writer
+        .write_all(b"NECT example.org:443 HTTP/1.1\r\n")
+        .await
+        .unwrap();
+    writer.shutdown().await.unwrap();
+
+    let mut initial_data = b"CON".to_vec();
+    extend_masking_initial_window(&mut reader, &mut initial_data).await;
+
+    assert!(
+        initial_data.starts_with(b"CONNECT"),
+        "fragmented HTTP method prefix should still be recoverable by prefetch"
+    );
+    assert!(
+        initial_data.len() <= 16,
+        "prefetch window must remain bounded"
+    );
+}
+
+#[tokio::test]
+async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
+    let mut seed = 0xD15C_A11E_2026_0322u64;
+
+    for _ in 0..128 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let len = ((seed & 0x3f) as usize).saturating_add(1);
+        let mut payload = vec![0u8; len];
+        for (idx, byte) in payload.iter_mut().enumerate() {
+            *byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
+        }
+
+        let (mut reader, mut writer) = duplex(1024);
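+        // Feed this iteration's garbage payload through one side of the pipe.
+        // The fail-closed invariant under test: with empty initial_data the
+        // prefetch stage must consume nothing, so every byte written here has
+        // to survive for the read_to_end check below. The xorshift-style seed
+        // stepping above is an illustrative deterministic generator only, not
+        // a cryptographic RNG.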
writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + assert!(initial_data.is_empty()); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(remaining, payload); + } +} + +#[tokio::test] +async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xD3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B); + let mut invalid_payload = vec![0u8; HANDSHAKE_LEN]; + invalid_payload[0] = 0xFF; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload); + let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant"); + let expected = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!( + n, 0, + "fallback stream must not append synthetic bytes on empty initial_data path" + ); + }); + + let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.245:56145".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs new file mode 100644 index 0000000..cbb6603 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -0,0 +1,72 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, advance, sleep}; + +async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(tail_delay_ms)).await; + let _ = writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await; + let _ = writer.shutdown().await; + }); + + let mut initial_data = b"C".to_vec(); + let mut prefetch_task = tokio::spawn(async move { + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_ms), + ) + .await; + initial_data + }); + + tokio::task::yield_now().await; + + if tail_delay_ms > 0 { + 
+        advance(Duration::from_millis(tail_delay_ms)).await;
+        tokio::task::yield_now().await;
+    }
+
+    if prefetch_ms > tail_delay_ms {
+        advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
+        tokio::task::yield_now().await;
+    }
+
+    let result = prefetch_task.await.expect("prefetch task must not panic");
+    writer_task.await.expect("writer task must not panic");
+    result
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_5ms_misses_15ms_tail() {
+    let got = run_strict_prefetch_case(5, 15).await;
+    assert_eq!(got, b"C".to_vec());
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_20ms_recovers_15ms_tail() {
+    let got = run_strict_prefetch_case(20, 15).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_50ms_recovers_35ms_tail() {
+    let got = run_strict_prefetch_case(50, 35).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
+    let got = run_strict_prefetch_case(20, 20).await;
+    assert!(got.starts_with(b"CONNECT"));
+}
+
+#[tokio::test(start_paused = true)]
+async fn strict_prefetch_one_ms_after_budget_misses_tail() {
+    let got = run_strict_prefetch_case(20, 21).await;
+    assert_eq!(got, b"C".to_vec());
+}
diff --git a/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs
new file mode 100644
index 0000000..bee1eb3
--- /dev/null
+++ b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs
@@ -0,0 +1,98 @@
+use super::*;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::time::{Duration, sleep, timeout};
+
+async fn extend_masking_initial_window_with_budget<R>(
+    reader: &mut R,
+    initial_data: &mut Vec<u8>,
+    prefetch_timeout: Duration,
+) where
+    R: AsyncRead + Unpin,
+{
+    if !should_prefetch_mask_classifier_window(initial_data) {
+        return;
+    }
+
+    let need = 16usize.saturating_sub(initial_data.len());
+    if need == 0 {
+        return;
+    }
+
+    let mut extra = [0u8; 16];
+    if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
+        && n > 0
+    {
+        initial_data.extend_from_slice(&extra[..n]);
+    }
+}
+
+async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
+    let (mut reader, mut writer) = duplex(1024);
+
+    let writer_task = tokio::spawn(async move {
+        sleep(Duration::from_millis(delayed_tail_ms)).await;
+        writer
+            .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
+            .await
+            .expect("tail bytes must be writable");
+        writer
+            .shutdown()
+            .await
+            .expect("writer shutdown must succeed");
+    });
+
+    let mut initial_data = b"C".to_vec();
+    extend_masking_initial_window_with_budget(
+        &mut reader,
+        &mut initial_data,
+        Duration::from_millis(prefetch_budget_ms),
+    )
+    .await;
+
+    writer_task
+        .await
+        .expect("writer task must not panic during matrix case");
+
+    initial_data.starts_with(b"CONNECT")
+}
+
+#[tokio::test]
+async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
+    let cases = [
+        // (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
+        (2u64, [true, true, true]),
+        (15u64, [false, true, true]),
+        (35u64, [false, false, true]),
+    ];
+
+    for (tail_delay_ms, expected) in cases {
+        let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
+        let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
+        let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
+
+        assert_eq!(
+            got_5, expected[0],
+            "5ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+        assert_eq!(
+            got_20, expected[1],
+            "20ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+        assert_eq!(
+            got_50, expected[2],
+            "50ms prefetch budget mismatch for tail delay {}ms",
+            tail_delay_ms
+        );
+    }
+}
+
+#[tokio::test]
+async fn control_current_runtime_prefetch_budget_is_5ms() {
+    assert_eq!(
+        MASK_CLASSIFIER_PREFETCH_TIMEOUT,
+        Duration::from_millis(5),
+        "matrix assumptions require current runtime prefetch budget to stay at 5ms"
+    );
+}
diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs
new file mode 100644
index 0000000..c3339e8
--- /dev/null
+++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs
@@ -0,0 +1,167 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use crate::crypto::sha256_hmac;
+use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
+use crate::protocol::tls;
+use std::sync::Arc;
+use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
+use tokio::time::{Duration, Instant};
+
+fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
+    Arc::new(UpstreamManager::new(
+        vec![UpstreamConfig {
+            upstream_type: UpstreamType::Direct {
+                interface: None,
+                bind_addresses: None,
+            },
+            weight: 1,
+            enabled: true,
+            scopes: String::new(),
+            selected_scope: String::new(),
+        }],
+        1,
+        1,
+        1,
+        1,
+        false,
+        stats,
+    ))
+}
+
+fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
+    let total_len = 5 + tls_len;
+    let mut handshake = vec![fill; total_len];
+
+    handshake[0] = 0x16;
+    handshake[1] = 0x03;
+    handshake[2] = 0x01;
+    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
+
+    let session_id_len: usize = 32;
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+async fn run_replay_candidate_session(
+    replay_checker: Arc<ReplayChecker>,
+    hello: &[u8],
+    peer: SocketAddr,
+    drive_mtproto_fail: bool,
+) -> Duration {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.beobachten = false;
+    cfg.censorship.mask = true;
+    cfg.censorship.mask_unix_sock = None;
+    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
+    cfg.censorship.mask_port = 1;
+    cfg.censorship.mask_timing_normalization_enabled = false;
+    cfg.access.ignore_time_skew = true;
+    cfg.access.users.insert(
+        "user".to_string(),
+        "abababababababababababababababab".to_string(),
+    );
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+    let beobachten = Arc::new(BeobachtenStore::new());
+
+    let (server_side, mut client_side) = duplex(65536);
+    let started = Instant::now();
+
+    let task = tokio::spawn(handle_client_stream(
+        server_side,
+        peer,
+        config,
+        stats.clone(),
+        new_upstream_manager(stats),
+        replay_checker,
+        Arc::new(BufferPool::new()),
+        Arc::new(SecureRandom::new()),
+        None,
+        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        None,
+        Arc::new(UserIpTracker::new()),
+        beobachten,
+        false,
+    ));
+
+    client_side.write_all(hello).await.unwrap();
+
+    if 
drive_mtproto_fail { + let mut server_hello_head = [0u8; 5]; + client_side + .read_exact(&mut server_hello_head) + .await + .unwrap(); + assert_eq!(server_hello_head[0], 0x16); + let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN); + invalid_mtproto_record.push(0x17); + invalid_mtproto_record.extend_from_slice(&TLS_VERSION); + invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); + invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side + .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") + .await + .unwrap(); + } + + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + started.elapsed() +} + +#[tokio::test] +async fn replay_reject_still_honors_masking_timing_budget() { + let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60))); + let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51); + + let seed_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.201:58001".parse().unwrap(), + true, + ) + .await; + + assert!( + seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250), + "seed replay-candidate run must honor masking timing budget without unbounded delay" + ); + + let replay_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.202:58002".parse().unwrap(), + false, + ) + .await; + + assert!( + replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250), + "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" + ); +} diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs new file mode 100644 index 0000000..8f9d832 --- /dev/null +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -0,0 +1,288 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[tokio::test] +async fn edge_mask_delay_bypassed_if_max_is_zero() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 10_000; + config.censorship.server_hello_delay_max_ms = 0; + + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + assert!(start.elapsed() < Duration::from_millis(50)); +} + +#[test] +fn edge_beobachten_ttl_clamps_exactly_to_24_hours() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 100_000; + + let ttl = beobachten_ttl(&config); + assert_eq!(ttl.as_secs(), 24 * 60 * 60); +} + +#[test] +fn edge_wrap_tls_application_record_empty_payload() { + let wrapped = wrap_tls_application_record(&[]); + assert_eq!(wrapped.len(), 5); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &[0, 0]); +} + +#[tokio::test] +async fn 
boundary_user_data_quota_exact_match_rejects() { + let user = "quota-boundary-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Arc::new(Stats::new()); + preload_user_quota(stats.as_ref(), user, 1024); + + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.10:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn boundary_user_expiration_in_past_rejects() { + let user = "expired-boundary-user"; + let mut config = ProxyConfig::default(); + let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1); + config + .access + .user_expirations + .insert(user.to_string(), expired_time); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.11:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::UserExpired { .. }))); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 300; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.12:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&vec![b'A'; 2000]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); + + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn 
security_classic_mode_disabled_masks_valid_length_payload() {
+    let mut cfg = ProxyConfig::default();
+    cfg.general.modes.classic = false;
+    cfg.general.modes.secure = false;
+    cfg.censorship.mask = true;
+
+    let config = Arc::new(cfg);
+    let stats = Arc::new(Stats::new());
+
+    let (server_side, mut client_side) = duplex(4096);
+    let handler = tokio::spawn(handle_client_stream(
+        server_side,
+        "198.51.100.15:55000".parse().unwrap(),
+        config,
+        stats.clone(),
+        Arc::new(UpstreamManager::new(
+            vec![],
+            1,
+            1,
+            1,
+            1,
+            false,
+            stats.clone(),
+        )),
+        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
+        Arc::new(BufferPool::new()),
+        Arc::new(SecureRandom::new()),
+        None,
+        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
+        None,
+        Arc::new(UserIpTracker::new()),
+        Arc::new(BeobachtenStore::new()),
+        false,
+    ));
+
+    client_side.write_all(&vec![0xEF; 64]).await.unwrap();
+    client_side.shutdown().await.unwrap();
+
+    let _ = tokio::time::timeout(Duration::from_secs(2), handler)
+        .await
+        .unwrap();
+    assert_eq!(stats.get_connects_bad(), 1);
+}
+
+#[tokio::test]
+async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() {
+    let user = "rapid-churn-user";
+    let mut config = ProxyConfig::default();
+    config
+        .access
+        .user_max_tcp_conns
+        .insert(user.to_string(), 10);
+
+    let stats = Arc::new(Stats::new());
+    let ip_tracker = Arc::new(UserIpTracker::new());
+    ip_tracker.set_user_limit(user, 1).await;
+
+    let peer = "198.51.100.16:55000".parse().unwrap();
+
+    for _ in 0..500 {
+        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
+            user,
+            &config,
+            stats.clone(),
+            peer,
+            ip_tracker.clone(),
+        )
+        .await
+        .unwrap();
+        reservation.release().await;
+    }
+
+    assert_eq!(stats.get_user_curr_connects(user), 0);
+    assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
+}
+
+#[tokio::test]
+async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() {
+    let (mut server_side, _client_side) = duplex(4096);
+    let mut empty_buf = &mut [][..];
+
+    let result = tokio::time::timeout(
+        Duration::from_millis(50),
+        read_with_progress(&mut server_side, &mut empty_buf),
+    )
+    .await;
+
+    assert!(result.is_ok());
+    assert_eq!(result.unwrap().unwrap(), 0);
+}
+
+#[tokio::test]
+async fn stress_read_with_progress_cancellation_safety() {
+    let (mut server_side, mut client_side) = duplex(4096);
+
+    client_side.write_all(b"12345").await.unwrap();
+
+    let mut buf = [0u8; 10];
+    let result = tokio::time::timeout(
+        Duration::from_millis(50),
+        read_with_progress(&mut server_side, &mut buf),
+    )
+    .await;
+
+    assert!(result.is_err());
+
+    client_side.write_all(b"67890").await.unwrap();
+    let mut buf2 = [0u8; 5];
+    server_side.read_exact(&mut buf2).await.unwrap();
+    assert_eq!(&buf2, b"67890");
+}
diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs
index 6338e23..1b46c6d 100644
--- a/src/proxy/tests/client_security_tests.rs
+++ b/src/proxy/tests/client_security_tests.rs
@@ -7,6 +7,9 @@ use crate::protocol::tls;
 use crate::proxy::handshake::HandshakeSuccess;
 use crate::stream::{CryptoReader, CryptoWriter};
 use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
+use rand::RngCore;
+use rand::SeedableRng;
+use rand::rngs::StdRng;
 use std::net::Ipv4Addr;
 use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
 use tokio::net::{TcpListener, TcpStream};
@@ -25,6 +28,220 @@ fn synthetic_local_addr_uses_configured_port_for_max() {
     assert_eq!(addr.port(), u16::MAX);
 }
 
+#[test]
+fn handshake_timeout_with_mask_grace_includes_mask_margin() {
+    let mut config = ProxyConfig::default();
+    config.timeouts.client_handshake = 2;
+
+    config.censorship.mask = false;
+    assert_eq!(
+        handshake_timeout_with_mask_grace(&config),
+        Duration::from_secs(2)
+    );
+
+    config.censorship.mask = true;
+    assert_eq!(
+        handshake_timeout_with_mask_grace(&config),
+        Duration::from_millis(2750),
+        "mask mode extends handshake timeout by 750 ms"
+    );
+}
+
+#[tokio::test]
+async fn read_with_progress_reads_partial_buffers_before_eof() {
+    let data = vec![0xAA, 0xBB, 0xCC];
+    let mut reader = std::io::Cursor::new(data);
+    let mut buf = [0u8; 5];
+
+    let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
+    assert_eq!(read, 3);
+    assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]);
+}
+
+#[test]
+fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() {
+    let peer: IpAddr = "10.10.10.10".parse().unwrap();
+    assert!(!is_trusted_proxy_source(peer, &[]));
+
+    let trusted = vec!["10.0.0.0/8".parse().unwrap()];
+    assert!(is_trusted_proxy_source(peer, &trusted));
+
+    let not_trusted = vec!["192.0.2.0/24".parse().unwrap()];
+    assert!(!is_trusted_proxy_source(peer, &not_trusted));
+}
+
+#[test]
+fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() {
+    let peer: IpAddr = "203.0.113.42".parse().unwrap();
+    let trust_all = vec!["0.0.0.0/0".parse().unwrap()];
+    assert!(is_trusted_proxy_source(peer, &trust_all));
+
+    let peer_v6: IpAddr = "2001:db8::1".parse().unwrap();
+    let trust_all_v6 = vec!["::/0".parse().unwrap()];
+    assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6));
+}
+
+struct ErrorReader;
+
+impl tokio::io::AsyncRead for ErrorReader {
+    fn poll_read(
+        self: std::pin::Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+        _buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        std::task::Poll::Ready(Err(std::io::Error::new(
+            std::io::ErrorKind::UnexpectedEof,
+            "fake error",
+        )))
+    }
+}
+
+#[tokio::test]
+async fn read_with_progress_returns_error_from_failed_reader() {
+    let mut reader = ErrorReader;
+    let mut buf = [0u8; 8];
+    let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err();
+    assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
+}
+
+#[test]
+fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() {
+    let mut config = ProxyConfig::default();
+    config.timeouts.client_handshake = u64::MAX;
+    config.censorship.mask = true;
+
+    let timeout = handshake_timeout_with_mask_grace(&config);
+    assert!(timeout >= Duration::from_secs(u64::MAX));
+}
+
+#[tokio::test]
+async fn read_with_progress_zero_length_buffer_returns_zero() {
+    let data = vec![1, 2, 3];
+    let mut reader = std::io::Cursor::new(data);
+    let mut buf = [];
+
+    let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
+    assert_eq!(read, 0);
+}
+
+#[test]
+fn handshake_timeout_without_mask_is_exact_base() {
+    let mut config = ProxyConfig::default();
+    config.timeouts.client_handshake = 7;
+    config.censorship.mask = false;
+
+    assert_eq!(
+        handshake_timeout_with_mask_grace(&config),
+        Duration::from_secs(7)
+    );
+}
+
+#[test]
+fn handshake_timeout_mask_enabled_adds_750ms() {
+    let mut config = ProxyConfig::default();
+    config.timeouts.client_handshake = 3;
+    config.censorship.mask = true;
+
+    assert_eq!(
+        handshake_timeout_with_mask_grace(&config),
+        Duration::from_millis(3750)
+    );
+}
+
+#[tokio::test]
+async fn read_with_progress_full_then_empty_transition() {
+    let data = vec![0x10, 0x20];
+    let mut cursor = std::io::Cursor::new(data);
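+    // Contract exercised here: read_with_progress keeps polling until the
+    // buffer is full or EOF is hit, so the first call below fills both bytes
+    // and the second call, with the cursor exhausted, returns 0 instead of
+    // hanging.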
+    let mut buf = [0u8; 2];
+
+    assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2);
+    assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0);
+}
+
+#[tokio::test]
+async fn read_with_progress_fragmented_io_works_over_multiple_calls() {
+    let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]);
+    let mut result = Vec::new();
+
+    for chunk_size in 1..=5 {
+        let mut b = vec![0u8; chunk_size];
+        let n = read_with_progress(&mut cursor, &mut b).await.unwrap();
+        result.extend_from_slice(&b[..n]);
+        if n == 0 {
+            break;
+        }
+    }
+
+    assert_eq!(result, vec![1, 2, 3, 4, 5]);
+}
+
+#[tokio::test]
+async fn read_with_progress_stress_randomized_chunk_sizes() {
+    for i in 0..128 {
+        let mut rng = StdRng::seed_from_u64(i as u64 + 1);
+        let input: Vec<u8> = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect();
+        let mut cursor = std::io::Cursor::new(input.clone());
+        let mut collected = Vec::new();
+
+        while cursor.position() < cursor.get_ref().len() as u64 {
+            let chunk = 1 + (rng.next_u32() as usize % 8);
+            let mut b = vec![0u8; chunk];
+            let read = read_with_progress(&mut cursor, &mut b).await.unwrap();
+            collected.extend_from_slice(&b[..read]);
+            if read == 0 {
+                break;
+            }
+        }
+
+        assert_eq!(collected, input);
+    }
+}
+
+#[test]
+fn is_trusted_proxy_source_boundary_narrow_ipv4() {
+    let matching = "172.16.0.1".parse().unwrap();
+    let not_matching = "172.15.255.255".parse().unwrap();
+    let cidr = vec!["172.16.0.0/12".parse().unwrap()];
+    assert!(is_trusted_proxy_source(matching, &cidr));
+    assert!(!is_trusted_proxy_source(not_matching, &cidr));
+}
+
+#[test]
+fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() {
+    let peer = "2001:db8::1".parse().unwrap();
+    let cidr = vec!["10.0.0.0/8".parse().unwrap()];
+    assert!(!is_trusted_proxy_source(peer, &cidr));
+}
+
+#[test]
+fn wrap_tls_application_record_reserved_chunks_look_reasonable() {
+    let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2];
+    let wrapped = wrap_tls_application_record(&payload);
+    assert!(wrapped.len() > payload.len());
+    assert!(wrapped.contains(&0x17));
+}
+
+#[test]
+fn wrap_tls_application_record_roundtrip_size_check() {
+    let payload_len = 3000;
+    let payload = vec![0x55; payload_len];
+    let wrapped = wrap_tls_application_record(&payload);
+
+    let mut idx = 0;
+    let mut consumed = 0;
+    while idx + 5 <= wrapped.len() {
+        assert_eq!(wrapped[idx], 0x17);
+        let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize;
+        consumed += len;
+        idx += 5 + len;
+        if idx >= wrapped.len() {
+            break;
+        }
+    }
+
+    assert_eq!(consumed, payload_len);
+}
+
 fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
 where
     R: tokio::io::AsyncRead + Unpin,
@@ -43,6 +260,11 @@ where
     CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
 }
 
+fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
+    let user_stats = stats.get_or_create_user_stats_handle(user);
+    stats.quota_charge_post_write(user_stats.as_ref(), bytes);
+}
+
 #[tokio::test]
 async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
     let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
@@ -2841,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
         .insert("user".to_string(), 1024);
 
     let stats = Stats::new();
-    stats.add_user_octets_from("user", 1024);
+    preload_user_quota(&stats, "user", 1024);
 
     let ip_tracker = UserIpTracker::new();
     let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();
diff --git 
a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs index 08f52d1..7964cdd 100644 --- a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs +++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs @@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize; let body_start = offset + 5; let body_end = body_start + len; - assert!(body_end <= record.len(), "declared TLS record length must be in-bounds"); + assert!( + body_end <= record.len(), + "declared TLS record length must be in-bounds" + ); recovered.extend_from_slice(&record[body_start..body_end]); offset = body_end; frames += 1; } - assert_eq!(offset, record.len(), "record parser must consume exact output size"); - assert_eq!(frames, 2, "oversized payload should split into exactly two records"); - assert_eq!(recovered, payload, "chunked records must preserve full payload"); + assert_eq!( + offset, + record.len(), + "record parser must consume exact output size" + ); + assert_eq!( + frames, 2, + "oversized payload should split into exactly two records" + ); + assert_eq!( + recovered, payload, + "chunked records must preserve full payload" + ); } diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 16fe8da..a731830 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() { "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut first = open_unknown_dc_log_append_anchored(&sanitized) @@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() { let content = fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!(lines.len(), 2, "expected one line per anchored append call"); assert!( lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), @@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut workers = Vec::new(); @@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { let content = fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); - assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); + assert_eq!( + lines.len(), + 64, + "expected one 
complete line per worker append" + ); for line in lines { assert!( line.starts_with("dc_idx="), @@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() { "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut file = open_unknown_dc_log_append_anchored(&sanitized) @@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() { "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let outside = std::env::temp_dir().join(format!( "telemt-unknown-dc-anchored-symlink-outside-{}.log", @@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let workers = 24usize; @@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { let content = fs::read_to_string(&sanitized.resolved_path) .expect("contention output file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!( lines.len(), workers * rounds, @@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); let mut readonly = std::fs::OpenOptions::new() diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs new file mode 100644 index 0000000..76347c4 --- /dev/null +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -0,0 +1,719 @@ +use super::*; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], 
timestamp: u32) -> Vec<u8> {
+    let session_id_len: usize = 32;
+    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
+    let mut handshake = vec![0x42u8; len];
+
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1);
+    }
+
+    let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+fn make_valid_tls_client_hello_with_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret,
&record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Edge Cases & Protocol Boundaries --- + +#[tokio::test] +async fn tls_minimum_viable_length_boundary() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut exact_min_handshake = vec![0x42u8; min_len]; + exact_min_handshake[min_len - 1] = 0; + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let digest = sha256_hmac(&secret, &exact_min_handshake); + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + let res = handle_tls_handshake( + &exact_min_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::Success(_)), + "Exact minimum length TLS handshake must succeed" + ); + + let short_handshake = vec![0x42u8; min_len - 1]; + let res_short = handle_tls_handshake( + &short_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res_short, HandshakeResult::BadClient { .. }), + "Handshake 1 byte shorter than minimum must fail closed" + ); +} + +#[tokio::test] +async fn mtproto_extreme_dc_index_serialization() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "22222222222222222222222222222222"; + let config = test_config_with_secret_hex(secret_hex); + for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() { + // Keep replay state independent per case so we validate dc_idx encoding, + // not duplicate-handshake rejection behavior. 
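+        // (Each case below gets a fresh checker and a distinct peer port, so a
+        // failure is attributable to the extreme dc_idx value alone; the fixed
+        // 60s replay TTL never expires within a single test run.)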
+ let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16); + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc); + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!( + success.dc_idx, extreme_dc, + "Extreme DC index {} must serialize/deserialize perfectly", + extreme_dc + ); + } + _ => panic!( + "MTProto handshake with extreme DC index {} failed", + extreme_dc + ), + } + } +} + +#[tokio::test] +async fn alpn_strict_case_and_padding_rejection() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"]; + + for bad_alpn in bad_alpns { + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]); + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "ALPN strict enforcement must reject {:?}", + bad_alpn + ); + } +} + +#[test] +fn ipv4_mapped_ipv6_bucketing_anomaly() { + let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402)); + + let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); + let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); + + assert_eq!( + norm_1, norm_2, + "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)" + ); + assert_eq!( + norm_1, + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), + "The bucket must be exactly ::0" + ); +} + +// --- Category 2: Adversarial & Black Hat --- + +#[tokio::test] +async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "55555555555555555555555555555555"; + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut invalid_handshake = valid_handshake; + invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF; + + let res_invalid = handle_mtproto_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. 
})); + + let res_valid = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid MTProto ciphertext must not poison the replay cache" + ); +} + +#[tokio::test] +async fn tls_invalid_session_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x66u8; 16]; + let config = test_config_with_secret_hex("66666666666666666666666666666666"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + let mut invalid_handshake = valid_handshake.clone(); + let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + invalid_handshake[session_idx] ^= 0xFF; + + let res_invalid = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid TLS payload must not poison the replay cache" + ); +} + +#[tokio::test] +async fn server_hello_delay_timing_neutrality_on_hmac_failure() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.server_hello_delay_min_ms = 50; + config.censorship.server_hello_delay_max_ms = 50; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let start = Instant::now(); + let res = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); + assert!( + elapsed >= Duration::from_millis(45), + "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels" + ); +} + +#[tokio::test] +async fn server_hello_delay_inversion_resilience() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x88u8; 16]; + let mut config = test_config_with_secret_hex("88888888888888888888888888888888"); + config.censorship.server_hello_delay_min_ms = 100; + config.censorship.server_hello_delay_max_ms = 10; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let start = Instant::now(); + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::Success(_))); + assert!( + elapsed >= Duration::from_millis(90), + "Delay logic must gracefully handle min > max inversions via max.max(min)" + ); +} + +#[tokio::test] +async fn mixed_valid_and_invalid_user_secrets_configuration() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + let _warn_guard = warned_secrets_test_lock().lock().unwrap(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.ignore_time_skew = true; + + for i in 0..9 { + let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" }; + config + .access + .users + .insert(format!("bad_user_{}", i), bad_secret.to_string()); + } + let valid_secret_hex = "99999999999999999999999999999999"; + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + config.general.modes.secure = true; + config.general.modes.classic = true; + config.general.modes.tls = true; + + let secret = [0x99u8; 16]; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "Proxy must gracefully skip invalid secrets and authenticate the valid one" + ); +} + +#[tokio::test] +async fn tls_emulation_fallback_when_cache_missing() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0xAAu8; 16]; + let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + config.censorship.tls_emulation = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "TLS emulation must gracefully fall back to standard ServerHello if cache is missing" + ); +} + +#[tokio::test] +async fn classic_mode_over_tls_transport_protocol_confusion() { + let 
_guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.classic = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap(); + + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + true, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior" + ); +} + +#[test] +fn generate_tg_nonce_never_emits_reserved_bytes() { + let client_enc_key = [0xCCu8; 32]; + let client_enc_iv = 123456789u128; + let rng = SecureRandom::new(); + + for _ in 0..10_000 { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 1, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert!( + !RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), + "Nonce must never start with reserved bytes" + ); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + assert!( + !RESERVED_NONCE_BEGINNINGS.contains(&first_four), + "Nonce must never match reserved 4-byte beginnings" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn dashmap_concurrent_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip_a: IpAddr = "192.0.2.13".parse().unwrap(); + let ip_b: IpAddr = "198.51.100.13".parse().unwrap(); + let mut tasks = Vec::new(); + + for i in 0..100 { + let target_ip = if i % 2 == 0 { ip_a } else { ip_b }; + tasks.push(tokio::spawn(async move { + for _ in 0..50 { + auth_probe_record_failure(target_ip, Instant::now()); + } + })); + } + + for task in tasks { + task.await + .expect("Task panicked during concurrent DashMap stress"); + } + + assert!( + auth_probe_is_throttled_for_testing(ip_a), + "IP A must be throttled after concurrent stress" + ); + assert!( + auth_probe_is_throttled_for_testing(ip_b), + "IP B must be throttled after concurrent stress" + ); +} + +#[test] +fn prototag_invalid_bytes_fail_closed() { + let invalid_tags: [[u8; 4]; 5] = [ + [0, 0, 0, 0], + [0xFF, 0xFF, 0xFF, 0xFF], + [0xDE, 0xAD, 0xBE, 0xEF], + [0xDD, 0xDD, 0xDD, 0xDE], + [0x11, 0x22, 0x33, 0x44], + ]; + + for tag in invalid_tags { + assert_eq!( + ProtoTag::from_bytes(tag), + None, + "Invalid ProtoTag bytes {:?} must fail closed", + tag + ); + } +} + +#[test] +fn auth_probe_eviction_hash_collision_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let now = Instant::now(); + + for i in 0..10_000u32 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8)); + auth_probe_record_failure_with_state(state, ip, now); + } + + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "Eviction logic must successfully bound the map size under heavy insertion stress" + ); +} + +#[test] +fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() { + let client_enc_key = [0xDDu8; 32]; + let client_enc_iv = 987654321u128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, 
+ ); + + let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce); + let zeros = [0u8; 64]; + let returned_keystream = returned_encryptor.encrypt(&zeros); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + + let mut manual_input = Vec::new(); + manual_input.extend_from_slice(&nonce); + manual_input.extend_from_slice(&zeros); + let manual_output = manual_encryptor.encrypt(&manual_input); + + assert_eq!( + returned_keystream, + &manual_output[64..128], + "encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length" + ); +} diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs new file mode 100644 index 0000000..77cea19 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -0,0 +1,96 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn adversarial_large_state_offsets_escape_first_scan_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut saw_offset_outside_first_window = false; + for i in 0..8_192u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(131)) & 0xff) as u8, + )); + let now = base + Duration::from_nanos(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + if start >= scan_limit { + saw_offset_outside_first_window = true; + break; + } + } + + assert!( + saw_offset_outside_first_window, + "scan start offset must cover the full auth-probe state, not only the first scan window" + ); +} + +#[test] +fn stress_large_state_offsets_cover_many_scan_windows() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut covered_windows = HashSet::new(); + for i in 0..16_384u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(17)) & 0xff) as u8, + )); + let now = base + Duration::from_micros(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + covered_windows.insert(start / scan_limit); + } + + assert!( + covered_windows.len() >= 16, + "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}", + covered_windows.len(), + state_len / scan_limit + ); +} + +#[test] +fn light_fuzz_offset_always_stays_inside_state_len() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xC0FF_EE12_3456_789Au64; + let base = Instant::now(); + + for _ in 0..8_192usize { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 16) 
as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x0fff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!( + start < state_len, + "scan offset must stay inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs new file mode 100644 index 0000000..c91a215 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -0,0 +1,99 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn edge_zero_state_len_yields_zero_start_offset() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44)); + let now = Instant::now(); + + assert_eq!( + auth_probe_scan_start_offset(ip, now, 0, 16), + 0, + "empty map must not produce non-zero scan offset" + ); +} + +#[test] +fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let scan_limit = 16usize; + let state_len = 65_536usize; + + let mut saw_offset_outside_window = false; + for i in 0..2048u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 203, + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + let now = base + Duration::from_micros(i as u64); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + start < state_len, + "start offset must stay within state length; start={start}, len={state_len}" + ); + if start >= scan_limit { + saw_offset_outside_window = true; + break; + } + } + + assert!( + saw_offset_outside_window, + "large-state eviction must sample beyond the first scan window" + ); +} + +#[test] +fn positive_state_smaller_than_scan_limit_caps_to_state_len() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17)); + let now = Instant::now(); + + for state_len in 1..32usize { + let start = auth_probe_scan_start_offset(ip, now, state_len, 64); + assert!( + start < state_len, + "start offset must never exceed state length when scan limit is larger" + ); + } +} + +#[test] +fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { + let _guard = auth_probe_test_guard(); + let mut seed = 0x5A41_5356_4C32_3236u64; + let base = Instant::now(); + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1); + let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0xffff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!( + start < state_len, + "scan offset must stay inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs new file mode 100644 index 0000000..bf97990 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -0,0 +1,116 @@ +use super::*; +use 
std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)); + let base = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..512u64 { + let now = base + Duration::from_nanos(i); + let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16); + uniq.insert(offset); + } + + assert!( + uniq.len() >= 256, + "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})", + uniq.len() + ); +} + +#[test] +fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { + let _guard = auth_probe_test_guard(); + let now = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..1024u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + (i >> 16) as u8, + (i >> 8) as u8, + i as u8, + (255 - (i as u8)), + )); + uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); + } + + assert!( + uniq.len() >= 512, + "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})", + uniq.len() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut workers = Vec::new(); + for worker in 0..8u8 { + workers.push(tokio::spawn(async move { + for i in 0..8192u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64)); + } + })); + } + + for worker in workers { + worker.await.expect("saturation worker must not panic"); + } + + assert!( + auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "state must remain hard-capped under parallel saturation churn" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1)); + let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1)); +} + +#[test] +fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xA55A_1357_2468_9BDFu64; + let base = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x1fff); + + let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + offset < state_len, + "scan offset must always remain inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs new file mode 100644 index 0000000..7176b1c --- /dev/null +++ b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs @@ -0,0 +1,42 @@ +use super::*; + +fn handshake_source() -> &'static str { + include_str!("../handshake.rs") +} + +#[test] +fn 
security_dec_key_derivation_is_zeroized_in_candidate_loop() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
+        "candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
+    );
+}
+
+#[test]
+fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
+        "candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
+    );
+}
+
+#[test]
+fn security_aes_ctr_initialization_uses_zeroizing_references() {
+    let src = handshake_source();
+    assert!(
+        src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
+            && src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
+        "AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
+    );
+}
+
+#[test]
+fn security_success_struct_copies_out_of_zeroizing_wrappers() {
+    let src = handshake_source();
+    assert!(
+        src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
+        "HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
+    );
+}
diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs
new file mode 100644
index 0000000..9782469
--- /dev/null
+++ b/src/proxy/tests/handshake_more_clever_tests.rs
@@ -0,0 +1,686 @@
+use super::*;
+use crate::crypto::{AesCtr, sha256, sha256_hmac};
+use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
+use rand::rngs::StdRng;
+use rand::{Rng, RngCore, SeedableRng};
+use std::collections::HashSet;
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::Barrier;
+
+// --- Helpers ---
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
+    let mut cfg = ProxyConfig::default();
+    cfg.access.users.clear();
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+    cfg.access.ignore_time_skew = true;
+    cfg.general.modes.secure = true;
+    cfg.general.modes.classic = true;
+    cfg.general.modes.tls = true;
+    cfg
+}
+
+fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
+    let session_id_len: usize = 32;
+    let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
+    let mut handshake = vec![0x42u8; len];
+
+    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+
+    let computed = sha256_hmac(secret, &handshake);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+
+    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
+        .copy_from_slice(&digest);
+    handshake
+}
+
+fn make_valid_mtproto_handshake(
+    secret_hex: &str,
+    proto_tag: ProtoTag,
+    dc_idx: i16,
+) -> [u8; HANDSHAKE_LEN] {
+    let secret = hex::decode(secret_hex).expect("secret hex must decode");
+    let mut handshake = [0x5Au8; HANDSHAKE_LEN];
+    for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
+        .iter_mut()
+        .enumerate()
+    {
+        *b = (idx as u8).wrapping_add(1);
+    }
+
+    let
dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
+    let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
+
+    let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
+    dec_key_input.extend_from_slice(dec_prekey);
+    dec_key_input.extend_from_slice(&secret);
+    let dec_key = sha256(&dec_key_input);
+
+    let mut dec_iv_arr = [0u8; IV_LEN];
+    dec_iv_arr.copy_from_slice(dec_iv_bytes);
+    let dec_iv = u128::from_be_bytes(dec_iv_arr);
+
+    let mut stream = AesCtr::new(&dec_key, dec_iv);
+    let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
+
+    let mut target_plain = [0u8; HANDSHAKE_LEN];
+    target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
+    target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
+
+    for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
+        handshake[idx] = target_plain[idx] ^ keystream[idx];
+    }
+
+    handshake
+}
+
+fn make_valid_tls_client_hello_with_sni_and_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    sni_host: &str,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+
+    let host_bytes = sni_host.as_bytes();
+    let mut sni_payload = Vec::new();
+    sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
+    sni_payload.push(0);
+    sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
+    sni_payload.extend_from_slice(host_bytes);
+    ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
+    ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
+    ext_blob.extend_from_slice(&sni_payload);
+
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    body.extend_from_slice(&ext_blob);
+
+    let mut handshake = Vec::new();
+    handshake.push(0x01);
+    let body_len = (body.len() as u32).to_be_bytes();
+    handshake.extend_from_slice(&body_len[1..4]);
+    handshake.extend_from_slice(&body);
+
+    let mut record = Vec::new();
+    record.push(TLS_RECORD_HANDSHAKE);
+    record.extend_from_slice(&[0x03, 0x01]);
+    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
+    record.extend_from_slice(&handshake);
+
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
+    let computed = sha256_hmac(secret, &record);
+    let mut digest = computed;
+    let ts = timestamp.to_le_bytes();
+    for i in 0..4 {
+        digest[28 + i] ^= ts[i];
+    }
+    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
+
+    record
+}
+
+// --- Category 1: Timing & Delay Invariants ---
+
+#[tokio::test]
+async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
+    let _guard = auth_probe_test_guard();
+    clear_auth_probe_state_for_testing();
+
+    let secret = [0x1Au8; 16];
+    let mut config =
test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a"); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 0; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let fut = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ); + + // Deterministic assertion: with max_ms == 0 there must be no sleep path, + // so the handshake should complete promptly under a generous timeout budget. + let res = tokio::time::timeout(Duration::from_millis(250), fut) + .await + .expect("max_ms=0 should bypass artificial delay and complete quickly"); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99)); + let now = Instant::now(); + + state.insert( + peer_ip, + AuthProbeState { + fail_streak: u32::MAX - 1, + blocked_until: now, + last_seen: now, + }, + ); + + auth_probe_record_failure_with_state(&state, peer_ip, now); + + let updated = state.get(&peer_ip).unwrap(); + assert_eq!(updated.fail_streak, u32::MAX); + + let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); + assert_eq!( + updated.blocked_until, expected_blocked_until, + "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS" + ); +} + +#[test] +fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { + let client_enc_key = [0x2Bu8; 32]; + let client_enc_iv = 1337u128; + let rng = SecureRandom::new(); + + let mut nonces = HashSet::new(); + let mut total_set_bits = 0usize; + let iterations = 5_000; + + for _ in 0..iterations { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + for byte in nonce.iter() { + total_set_bits += byte.count_ones() as usize; + } + + assert!( + nonces.insert(nonce), + "generate_tg_nonce emitted a duplicate nonce! RNG is stuck." + ); + } + + let total_bits = iterations * HANDSHAKE_LEN * 8; + let ratio = (total_set_bits as f64) / (total_bits as f64); + assert!( + ratio > 0.48 && ratio < 0.52, + "Nonce entropy is degraded. 
Set bit ratio: {}", + ratio + ); +} + +#[tokio::test] +async fn mtproto_multi_user_decryption_isolation() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "user_a".to_string(), + "11111111111111111111111111111111".to_string(), + ); + config.access.users.insert( + "user_b".to_string(), + "22222222222222222222222222222222".to_string(), + ); + let good_secret_hex = "33333333333333333333333333333333"; + config + .access + .users + .insert("user_c".to_string(), good_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!( + success.user, "user_c", + "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user" + ); + } + _ => panic!( + "Multi-user MTProto handshake failed. Decryption buffer might be mutating in place." + ), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_secret_warning_lock_contention_and_bound() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let tasks = 50; + let iterations_per_task = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for t in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + for i in 0..iterations_per_task { + let user_name = format!("contention_user_{}_{}", t, i); + warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None); + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let warned = INVALID_SECRET_WARNED.get().unwrap(); + let guard = warned + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + + assert_eq!( + guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mtproto_strict_concurrent_replay_race_condition() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; + let config = Arc::new(test_config_with_secret_hex(secret_hex)); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let valid_handshake = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 1, + )); + + let tasks = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for i in 0..tasks { + let b = barrier.clone(); + let cfg = config.clone(); + let rc = replay_checker.clone(); + let hs = valid_handshake.clone(); + + handles.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), + 10000 + i as u16, + ); + b.wait().await; + handle_mtproto_handshake( + &hs, + tokio::io::empty(), + tokio::io::sink(), + peer, + &cfg, + &rc, + false, + None, + ) + 
.await + })); + } + + let mut successes = 0; + let mut failures = 0; + + for handle in handles { + match handle.await.unwrap() { + HandshakeResult::Success(_) => successes += 1, + HandshakeResult::BadClient { .. } => failures += 1, + _ => panic!("Unexpected error result in concurrent MTProto replay test"), + } + } + + assert_eq!( + successes, 1, + "Replay cache race condition allowed multiple identical MTProto handshakes to succeed" + ); + assert_eq!( + failures, + tasks - 1, + "Replay cache failed to forcefully reject concurrent duplicates" + ); +} + +#[tokio::test] +async fn tls_alpn_zero_length_protocol_handled_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x5Bu8; 16]; + let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); + + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "0-length ALPN must be safely rejected without panicking" + ); +} + +#[tokio::test] +async fn tls_sni_massive_hostname_does_not_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x6Cu8; 16]; + let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); + + let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!( + res, + HandshakeResult::Success(_) | HandshakeResult::BadClient { .. } + ), + "Massive SNI hostname must be processed or ignored without stack overflow or panic" + ); +} + +#[tokio::test] +async fn tls_progressive_truncation_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x7Du8; 16]; + let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); + + let valid_handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let full_len = valid_handshake.len(); + + // Truncated corpus only: full_len is a valid baseline and should not be + // asserted as BadClient in a truncation-specific test. + for i in (0..full_len).rev() { + let truncated = &valid_handshake[..i]; + let res = handle_tls_handshake( + truncated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. 
}), + "Truncated TLS handshake at len {} must fail safely without panicking", + i + ); + } +} + +#[tokio::test] +async fn mtproto_pure_entropy_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap(); + + let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE); + + for _ in 0..10_000 { + let mut noise = [0u8; HANDSHAKE_LEN]; + seeded.fill_bytes(&mut noise); + + let res = handle_mtproto_handshake( + &noise, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Pure entropy MTProto payload must fail closed and never panic" + ); + } +} + +#[test] +fn decode_user_secret_odd_length_hex_rejection() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert( + "odd_user".to_string(), + "1234567890123456789012345678901".to_string(), + ); + + let decoded = decode_user_secrets(&config, None); + assert!( + decoded.is_empty(), + "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping" + ); +} + +#[test] +fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112)); + let now = Instant::now(); + + let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5; + state.insert( + peer_ip, + AuthProbeState { + fail_streak: extreme_streak, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); + assert!( + is_throttled, + "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period" + ); +} + +#[test] +fn auth_probe_saturation_note_resets_retention_window() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let base_time = Instant::now(); + + auth_probe_note_saturation(base_time); + let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1); + auth_probe_note_saturation(later); + + let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5); + + // This call may return false if backoff has elapsed, but it must not clear + // the saturation state because `later` refreshed last_seen. 
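+    // (Concretely: check_time sits past base_time + AUTH_PROBE_TRACK_RETENTION_SECS
+    // but before later + AUTH_PROBE_TRACK_RETENTION_SECS, so retention anchored at
+    // the refreshed last_seen has not yet expired.)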
+    let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time);
+    let guard = auth_probe_saturation_state_lock();
+    assert!(
+        guard.is_some(),
+        "Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window"
+    );
+}
+
+#[test]
+fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
+    let mut config = ProxyConfig::default();
+    config.general.modes.classic = false;
+    config.general.modes.secure = true;
+    config.general.modes.tls = false;
+
+    assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
+    assert!(!mode_enabled_for_proto(
+        &config,
+        ProtoTag::Intermediate,
+        false
+    ));
+}
+
+#[test]
+fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() {
+    let mut config = ProxyConfig::default();
+    config.general.modes.classic = true;
+    config.general.modes.secure = false;
+    config.general.modes.tls = false;
+
+    assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false));
+}
+
+#[test]
+fn ipv6_localhost_and_unspecified_normalization() {
+    let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
+    let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
+
+    let norm_local = normalize_auth_probe_ip(localhost);
+    let norm_unspec = normalize_auth_probe_ip(unspecified);
+
+    let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
+
+    assert_eq!(norm_local, expected_bucket);
+    assert_eq!(norm_unspec, expected_bucket);
+}
diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs
new file mode 100644
index 0000000..1e27ed5
--- /dev/null
+++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs
@@ -0,0 +1,340 @@
+use super::*;
+use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
+use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::Barrier;
+
+fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
+    auth_probe_test_lock()
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
+    let mut cfg = ProxyConfig::default();
+    cfg.access.users.clear();
+    cfg.access
+        .users
+        .insert("user".to_string(), secret_hex.to_string());
+    cfg.access.ignore_time_skew = true;
+    cfg.general.modes.secure = true;
+    cfg.general.modes.classic = true;
+    cfg.general.modes.tls = true;
+    cfg
+}
+
+fn make_valid_tls_client_hello_with_alpn(
+    secret: &[u8],
+    timestamp: u32,
+    alpn_protocols: &[&[u8]],
+) -> Vec<u8> {
+    let mut body = Vec::new();
+    body.extend_from_slice(&TLS_VERSION);
+    body.extend_from_slice(&[0u8; 32]);
+    body.push(32);
+    body.extend_from_slice(&[0x42u8; 32]);
+    body.extend_from_slice(&2u16.to_be_bytes());
+    body.extend_from_slice(&[0x13, 0x01]);
+    body.push(1);
+    body.push(0);
+
+    let mut ext_blob = Vec::new();
+    if !alpn_protocols.is_empty() {
+        let mut alpn_list = Vec::new();
+        for proto in alpn_protocols {
+            alpn_list.push(proto.len() as u8);
+            alpn_list.extend_from_slice(proto);
+        }
+        let mut alpn_data = Vec::new();
+        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        alpn_data.extend_from_slice(&alpn_list);
+
+        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
+        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
+        ext_blob.extend_from_slice(&alpn_data);
+    }
+
+    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
+    
body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[tokio::test] +async fn tls_alpn_reject_does_not_pollute_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let mut config = test_config_with_secret_hex("11111111111111111111111111111111"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let before = replay_checker.stats(); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + let after = replay_checker.stats(); + + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); + assert_eq!( + before.total_additions, after.total_additions, + "ALPN policy reject must not add TLS digest into replay cache" + ); +} + +#[tokio::test] +async fn tls_truncated_session_id_len_fails_closed_without_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("33333333333333333333333333333333"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut malicious = vec![0x42u8; min_len]; + malicious[min_len - 1] = u8::MAX; + + let res = handle_tls_handshake( + &malicious, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let same = Instant::now(); + + for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8)); + state.insert( + ip, + AuthProbeState { + fail_streak: 7, + blocked_until: same, + last_seen: same, + }, + ); + } + + let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21)); + auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1)); + + assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES); + assert!(state.contains_key(&new_ip)); +} + +#[test] +fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _hold = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + panic!("intentional poison for regression coverage"); + }); + let _ = poison_thread.join(); + + clear_auth_probe_state_for_testing(); + + let guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none()); +} + +#[tokio::test] +async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() { + let _probe_guard = auth_probe_test_guard(); + let _warn_guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "short_user".to_string(), + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(), + ); + + let valid_secret_hex = "77777777777777777777777777777777"; + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap(); + let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_))); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn 
saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80)); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let state = auth_probe_state_map(); + state.insert( + peer_ip, + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now, + last_seen: now, + }, + ); + + let tasks = 32; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for _ in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let final_state = state.get(&peer_ip).expect("state must exist"); + assert!( + final_state.fail_streak + >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + ); + assert!(auth_probe_should_apply_preauth_throttle( + peer_ip, + Instant::now() + )); +} diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index d06f63e..0e43d35 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { } } +#[tokio::test] +async fn tls_unknown_sni_drop_policy_returns_hard_error() { + let secret = [0x48u8; 16]; + let mut config = test_config_with_secret_hex("48484848484848484848484848484848"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!( + result, + HandshakeResult::Error(ProxyError::UnknownTlsSni) + )); +} + +#[tokio::test] +async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() { + let secret = [0x49u8; 16]; + let mut config = test_config_with_secret_hex("49494949494949494949494949494949"); + config.censorship.unknown_sni_action = UnknownSniAction::Mask; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn tls_missing_sni_keeps_legacy_auth_path() { + let secret = [0x4Au8; 16]; + let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; @@ -1560,6 +1643,32 @@ fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_evictio ); } +#[test] +fn unknown_sni_warn_cooldown_first_event_is_warn_and_repeated_events_are_info_until_window_expires() +{ + let _guard = unknown_sni_warn_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_unknown_sni_warn_state_for_testing(); + + let now = Instant::now(); + + assert!( + should_emit_unknown_sni_warn_for_testing(now), + "first unknown SNI event must be eligible for WARN emission" + ); + assert!( + !should_emit_unknown_sni_warn_for_testing(now + Duration::from_secs(1)), + "events inside cooldown window must be demoted from WARN to INFO" + ); + assert!( + should_emit_unknown_sni_warn_for_testing( + now + Duration::from_secs(UNKNOWN_SNI_WARN_COOLDOWN_SECS) + ), + "once cooldown expires, next unknown SNI event must be WARN-eligible again" + ); +} + #[test] fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { let _guard = auth_probe_test_lock() diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs new file mode 100644 index 0000000..13d112c --- /dev/null +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -0,0 +1,318 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, + salt: u8, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1).wrapping_add(salt); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + 
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn median_ns(samples: &mut [u128]) -> u128 { + samples.sort_unstable(); + samples[samples.len() / 2] +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn mtproto_user_scan_timing_manual_benchmark() { + let 
_guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "target_user"; + let target_secret_hex = "dededededededededededededededede"; + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); + let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); + let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap(); + let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap(); + + let mut preferred_samples = Vec::with_capacity(ITERATIONS); + let mut full_scan_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let handshake = make_valid_mtproto_handshake( + target_secret_hex, + ProtoTag::Secure, + 1 + i as i16, + (i % 251) as u8, + ); + + let started_preferred = Instant::now(); + let preferred = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_a, + &config, + &replay_checker_preferred, + false, + Some(preferred_user), + ) + .await; + preferred_samples.push(started_preferred.elapsed().as_nanos()); + assert!(matches!(preferred, HandshakeResult::Success(_))); + + let started_scan = Instant::now(); + let full_scan = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_b, + &config, + &replay_checker_full_scan, + false, + None, + ) + .await; + full_scan_samples.push(started_scan.elapsed().as_nanos()); + assert!(matches!(full_scan, HandshakeResult::Success(_))); + } + + let preferred_median = median_ns(&mut preferred_samples); + let full_scan_median = median_ns(&mut full_scan_samples); + + let ratio = if preferred_median == 0 { + 0.0 + } else { + full_scan_median as f64 / preferred_median as f64 + }; + + println!( + "manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}" + ); + + assert!( + full_scan_median >= preferred_median, + "full user scan should not be faster than preferred-user path in this benchmark" + ); +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() { + let _guard = auth_probe_test_guard(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "user-b"; + let target_secret_hex = "abababababababababababababababab"; + let target_secret = [0xABu8; 16]; + + let mut config = ProxyConfig::default(); + config.general.modes.tls = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let mut sni_samples = Vec::with_capacity(ITERATIONS); + let mut no_sni_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let with_sni = make_valid_tls_client_hello_with_sni_and_alpn( + &target_secret, + i as u32, + preferred_user, + &[b"h2"], + ); + let no_sni = 
make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000)); + + let started_sni = Instant::now(); + let sni_secrets = decode_user_secrets(&config, Some(preferred_user)); + let sni_result = tls::validate_tls_handshake_with_replay_window( + &with_sni, + &sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + sni_samples.push(started_sni.elapsed().as_nanos()); + assert!(sni_result.is_some()); + + let started_no_sni = Instant::now(); + let no_sni_secrets = decode_user_secrets(&config, None); + let no_sni_result = tls::validate_tls_handshake_with_replay_window( + &no_sni, + &no_sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + no_sni_samples.push(started_no_sni.elapsed().as_nanos()); + assert!(no_sni_result.is_some()); + } + + let sni_median = median_ns(&mut sni_samples); + let no_sni_median = median_ns(&mut no_sni_samples); + + let ratio = if sni_median == 0 { + 0.0 + } else { + no_sni_median as f64 / sni_median as f64 + }; + + println!( + "manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 3e860e8..a977409 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u ]; let mut meaningful_improvement_seen = false; - let mut baseline_sum = 0.0f64; - let mut hardened_sum = 0.0f64; - let mut pair_count = 0usize; + let mut informative_baseline_sum = 0.0f64; + let mut informative_hardened_sum = 0.0f64; + let mut informative_pair_count = 0usize; + let mut low_info_baseline_sum = 0.0f64; + let mut low_info_hardened_sum = 0.0f64; + let mut low_info_pair_count = 0usize; let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64; let tolerated_pair_regression = acc_quant_step + 0.03; @@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u hardened_acc <= baseline_acc + tolerated_pair_regression, "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" ); + informative_baseline_sum += baseline_acc; + informative_hardened_sum += hardened_acc; + informative_pair_count += 1; + } else { + // Low-information pairs (near-random baseline separability) are expected + // to exhibit quantized jitter at low sample counts; do not fold them into + // strict average-regression checks used for informative side-channel signal. 
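+            // They still feed the looser low-info drift budget asserted further below,
+            // so a regression cannot hide in this bucket entirely.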
+            low_info_baseline_sum += baseline_acc; + low_info_hardened_sum += hardened_acc; + low_info_pair_count += 1; } println!( @@ -531,20 +544,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u if hardened_acc + 0.05 <= baseline_acc { meaningful_improvement_seen = true; } - - baseline_sum += baseline_acc; - hardened_sum += hardened_acc; - pair_count += 1; } - let baseline_avg = baseline_sum / pair_count as f64; - let hardened_avg = hardened_sum / pair_count as f64; + assert!( + informative_pair_count > 0, + "expected at least one informative pair for timing-separability guard" + ); + + let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64; + let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64; assert!( - hardened_avg <= baseline_avg + 0.10, - "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + informative_hardened_avg <= informative_baseline_avg + 0.10, + "normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}" ); + if low_info_pair_count > 0 { + let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64; + let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64; + assert!( + low_info_hardened_avg <= low_info_baseline_avg + 0.40, + "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}" + ); + } + + // Optional signal only: do not require improvement on every run because + // noisy CI schedulers can flatten pairwise differences at low sample counts.
+    let _ = meaningful_improvement_seen; diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs new file mode 100644 index 0000000..a6f6386 --- /dev/null +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; +use tokio::time::{Duration, timeout}; + +struct EndlessReader { + produced: Arc<AtomicUsize>, +} + +impl AsyncRead for EndlessReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + let len = buf.remaining().max(1); + let fill = vec![0xAA; len]; + buf.put_slice(&fill); + self.produced.fetch_add(len, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[test] +fn loop_guard_unspecified_bind_uses_interface_inventory() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap(); + let interfaces = vec!["192.168.44.10".parse().unwrap()]; + + assert!(is_mask_target_local_listener_with_interfaces( + "mask.example", + 443, + local, + Some(resolved), + &interfaces, + )); +} + +#[tokio::test] +async fn consume_client_data_stops_after_byte_cap_without_eof() { + let produced = Arc::new(AtomicUsize::new(0)); + let reader = EndlessReader { + produced: Arc::clone(&produced), + }; + let cap = 10_000usize; + + consume_client_data(reader, cap).await; + + let total = produced.load(Ordering::Relaxed); + assert!( + total >= cap, + "consume path must read at least up to cap before stopping" + ); + assert!( + total <= cap + 8192, + "consume path must stop within one read chunk above cap" + ); +} + +#[test] +fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = masking_beobachten_ttl(&config); + assert_eq!(ttl, std::time::Duration::from_secs(60)); +} + +#[test] +fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 0; + config.censorship.mask_timing_normalization_ceiling_ms = 0; + + let budget = mask_outcome_target_budget(&config); + assert_eq!( + budget, + Duration::from_millis(0), + "zero floor/ceiling must produce zero extra normalization budget" + ); +} + +#[tokio::test] +async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap(); + let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), 
b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "loop guard must fail closed before any recursive PROXY protocol amplification" + ); +} diff --git a/src/proxy/tests/masking_aggressive_mode_security_tests.rs b/src/proxy/tests/masking_aggressive_mode_security_tests.rs index a77fc14..7356dc0 100644 --- a/src/proxy/tests/masking_aggressive_mode_security_tests.rs +++ b/src/proxy/tests/masking_aggressive_mode_security_tests.rs @@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() { let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await; let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await; - assert!(legacy < floor, "legacy mode should keep timeout path unshaped"); + assert!( + legacy < floor, + "legacy mode should keep timeout path unshaped" + ); assert!( aggressive >= floor, "aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})" diff --git a/src/proxy/tests/masking_classification_completeness_security_tests.rs b/src/proxy/tests/masking_classification_completeness_security_tests.rs new file mode 100644 index 0000000..35bf87b --- /dev/null +++ b/src/proxy/tests/masking_classification_completeness_security_tests.rs @@ -0,0 +1,16 @@ +use super::*; + +#[test] +fn detect_client_type_recognizes_extended_http_probe_verbs() { + assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP"); +} + +#[test] +fn detect_client_type_recognizes_fragmented_http_method_prefixes() { + assert_eq!(detect_client_type(b"CO"), "HTTP"); + assert_eq!(detect_client_type(b"CON"), "HTTP"); + assert_eq!(detect_client_type(b"TR"), "HTTP"); + assert_eq!(detect_client_type(b"PAT"), "HTTP"); +} diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs new file mode 100644 index 0000000..718189c --- /dev/null +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use crate::network::dns_overrides::install_entries; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +async fn run_connect_failure_case( + host: &str, + port: u16, + timing_normalization_enabled: bool, + peer: SocketAddr, +) -> Duration { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(host.to_string()); + config.censorship.mask_port = port; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n"; + + let (mut client_writer, client_reader) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, 
+ &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!( + n, 0, + "connect-failure path must close client-visible writer" + ); + + started.elapsed() +} + +#[tokio::test] +async fn connect_failure_refusal_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) + .parse() + .unwrap(); + let elapsed = + run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized refusal path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized refusal path must honor baseline connect budget without stalling" + ); + } + } +} + +#[tokio::test] +async fn connect_failure_overridden_hostname_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + // Make hostname resolution deterministic in tests so timing ceilings are meaningful. + install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap(); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "mask.invalid", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized overridden-host path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized overridden-host path must honor baseline connect budget without stalling" + ); + } + } + + install_entries(&[]).unwrap(); +} diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs new file mode 100644 index 0000000..f2c39a2 --- /dev/null +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -0,0 +1,88 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0x42]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn stalling_client_terminates_at_idle_not_relay_timeout() { + let reader = OneByteThenStall { sent: false }; + let started = Instant::now(); + + let result = tokio::time::timeout( + MASK_RELAY_TIMEOUT, 
+ consume_client_data(reader, MASK_BUFFER_SIZE * 4), + ) + .await; + + assert!( + result.is_ok(), + "consume_client_data should complete by per-read idle timeout, not hit relay timeout" + ); + + let elapsed = started.elapsed(); + assert!( + elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2), + "consume_client_data returned too quickly for idle-timeout path: {elapsed:?}" + ); + assert!( + elapsed < MASK_RELAY_TIMEOUT, + "consume_client_data waited full relay timeout ({elapsed:?}); \ + per-read idle timeout is missing" + ); +} + +#[tokio::test] +async fn fast_reader_drains_to_eof() { + let data = vec![0xAAu8; 32 * 1024]; + let reader = std::io::Cursor::new(data); + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX)) + .await + .expect("consume_client_data did not complete for fast EOF reader"); +} + +#[tokio::test] +async fn io_error_terminates_cleanly() { + struct ErrReader; + + impl AsyncRead for ErrReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "simulated reset", + ))) + } + } + + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(ErrReader, usize::MAX), + ) + .await + .expect("consume_client_data did not return on I/O error"); +} diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs new file mode 100644 index 0000000..12287b5 --- /dev/null +++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs @@ -0,0 +1,64 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::task::JoinSet; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0xAA]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn consume_stall_stress_finishes_within_idle_budget() { + let mut set = JoinSet::new(); + let started = Instant::now(); + + for _ in 0..64 { + set.spawn(async { + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(OneByteThenStall { sent: false }, usize::MAX), + ) + .await + .expect("consume_client_data exceeded relay timeout under stall load"); + }); + } + + while let Some(res) = set.join_next().await { + res.unwrap(); + } + + // Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling + // for 100ms should complete well under a strict 600ms boundary. 
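+    // Expressing it as MASK_RELAY_TIMEOUT * 3 keeps the bound tied to the constant
+    // rather than a hardcoded 600ms figure.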
+ assert!( + started.elapsed() < MASK_RELAY_TIMEOUT * 3, + "stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking" + ); +} + +#[tokio::test] +async fn consume_zero_cap_returns_immediately() { + let started = Instant::now(); + consume_client_data(tokio::io::empty(), 0).await; + assert!( + started.elapsed() < MASK_RELAY_IDLE_TIMEOUT, + "zero byte cap must return immediately" + ); +} diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..650731c --- /dev/null +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -0,0 +1,225 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +fn make_self_target_config( + timing_normalization_enabled: bool, + floor_ms: u64, + ceiling_ms: u64, + beobachten_enabled: bool, +) -> ProxyConfig { + let mut config = ProxyConfig::default(); + config.general.beobachten = beobachten_enabled; + config.general.beobachten_minutes = 5; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + config +} + +async fn run_self_target_refusal( + config: ProxyConfig, + peer: SocketAddr, + initial: &'static [u8], +) -> Duration { + let beobachten = BeobachtenStore::new(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(3), task) + .await + .expect("self-target refusal must complete in bounded time") + .expect("self-target refusal task must not panic"); + + started.elapsed() +} + +#[tokio::test] +async fn positive_self_target_refusal_honors_normalization_floor() { + let config = make_self_target_config(true, 120, 120, false); + let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260), + "normalized self-target refusal must stay within expected envelope" + ); +} + +#[tokio::test] +async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() { + let config = make_self_target_config(false, 240, 240, false); + let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(180), + "non-normalized path must not inherit normalization floor delays" + ); +} + +#[tokio::test] +async fn edge_ceiling_below_floor_uses_floor_fail_closed() { + let config = make_self_target_config(true, 140, 80, false); + let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + 
assert!( + elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280), + "ceiling below floor must fail closed to the floor envelope" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_probes_keep_bounded_latency_variance() { + let mut handles = Vec::new(); + for i in 0..6u16 { + handles.push(tokio::spawn(async move { + let config = make_self_target_config(true, 120, 120, false); + let peer: SocketAddr = format!("203.0.113.44:{}", 54044 + i) + .parse() + .expect("valid peer"); + run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await + })); + } + + let mut min = Duration::MAX; + let mut max = Duration::ZERO; + for handle in handles { + let elapsed = handle.await.expect("parallel probe must not panic"); + if elapsed < min { + min = elapsed; + } + if elapsed > max { + max = elapsed; + } + assert!( + elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320), + "parallel probe latency must stay bounded under normalization" + ); + } + + assert!( + max.saturating_sub(min) <= Duration::from_millis(130), + "normalization should limit path variance across adversarial parallel probes" + ); +} + +#[tokio::test] +async fn integration_beobachten_records_probe_classification_on_refusal() { + let config = make_self_target_config(false, 0, 0, true); + let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer"); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + beobachten.snapshot_text(Duration::from_secs(60)) + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + let snapshot = timeout(Duration::from_secs(3), task) + .await + .expect("integration task must complete") + .expect("integration task must not panic"); + + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.71-1")); +} + +#[tokio::test] +async fn light_fuzz_timing_configuration_matrix_is_bounded() { + let mut seed = 0xA17E_55AA_2026_0323u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let enabled = (seed & 1) == 0; + let floor = (seed >> 8) % 180; + let ceiling = (seed >> 24) % 180; + let config = make_self_target_config(enabled, floor, ceiling, false); + let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16)) + .parse() + .expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(420), + "fuzz case must stay bounded and never hang" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { + let workers = 64usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let config = make_self_target_config(false, 0, 0, false); + let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16) + .parse() + .expect("valid peer"); + run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await + })); + } + + timeout(Duration::from_secs(5), async { + for task in tasks { + let elapsed = task.await.expect("stress task must not panic"); + assert!( + elapsed < Duration::from_millis(260), + "stress refusal must remain bounded without normalization" + ); + } + }) + .await + .expect("high-fanout refusal workload must complete without deadlock"); +} diff --git a/src/proxy/tests/masking_http2_preface_integration_security_tests.rs b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs new file mode 100644 index 0000000..7f1c03f --- /dev/null +++ b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs @@ -0,0 +1,55 @@ +use super::*; +use tokio::net::TcpListener; +use tokio::time::Duration; + +#[tokio::test] +async fn http2_preface_is_forwarded_and_recorded_as_http() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + 
let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let preface = preface.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; preface.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, preface); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let (client_reader, _client_writer) = tokio::io::duplex(512); + let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + client_reader, + client_visible_writer, + &preface, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.130-1")); +} diff --git a/src/proxy/tests/masking_http2_probe_classification_security_tests.rs b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs new file mode 100644 index 0000000..34e04a9 --- /dev/null +++ b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs @@ -0,0 +1,92 @@ +use super::*; + +#[test] +fn full_http2_preface_classified_as_http_probe() { + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + assert!( + is_http_probe(preface), + "HTTP/2 connection preface must be classified as HTTP probe" + ); +} + +#[test] +fn partial_http2_preface_3_bytes_classified() { + assert!( + is_http_probe(b"PRI"), + "3-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn partial_http2_preface_2_bytes_classified() { + assert!( + is_http_probe(b"PR"), + "2-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn existing_http1_methods_unaffected() { + for prefix in [ + b"GET / HTTP/1.1\r\n".as_ref(), + b"POST /api HTTP/1.1\r\n".as_ref(), + b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(), + b"TRACE / HTTP/1.1\r\n".as_ref(), + b"PATCH / HTTP/1.1\r\n".as_ref(), + ] { + assert!(is_http_probe(prefix)); + } +} + +#[test] +fn non_http_data_not_classified() { + for data in [ + b"\x16\x03\x01\x00\xf1".as_ref(), + b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(), + b"\x00\x01\x02\x03".as_ref(), + b"".as_ref(), + b"P".as_ref(), + ] { + assert!(!is_http_probe(data)); + } +} + +#[test] +fn light_fuzz_non_http_prefixes_not_misclassified() { + // Deterministic pseudo-fuzz to exercise classifier edges while avoiding + // known HTTP method and partial windows. 
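+    // The wrapping_mul/wrapping_add constants below (1664525, 1013904223) are the
+    // classic Numerical Recipes LCG parameters.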
+ let mut x = 0x1234_5678u32; + for _ in 0..1024 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = 4 + ((x >> 8) as usize % 12); + let mut data = vec![0u8; len]; + for byte in &mut data { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = (x & 0xFF) as u8; + } + + if [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"DELETE".as_ref(), + b"OPTIONS".as_ref(), + b"CONNECT".as_ref(), + b"TRACE".as_ref(), + b"PATCH".as_ref(), + b"PRI ".as_ref(), + ] + .iter() + .any(|m| data.starts_with(m)) + { + continue; + } + + assert!( + !is_http_probe(&data), + "non-http pseudo-fuzz input misclassified: {:?}", + &data[..data.len().min(8)] + ); + } +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs new file mode 100644 index 0000000..c8f3ec0 --- /dev/null +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -0,0 +1,85 @@ +use super::*; + +#[test] +fn exact_four_byte_http_tokens_are_classified() { + for token in [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"PRI ".as_ref(), + ] { + assert!( + is_http_probe(token), + "exact 4-byte token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn exact_four_byte_non_http_tokens_are_not_classified() { + for token in [ + b"GEX ".as_ref(), + b"POXT".as_ref(), + b"HEA/".as_ref(), + b"PU\0 ".as_ref(), + b"PRI/".as_ref(), + ] { + assert!( + !is_http_probe(token), + "non-HTTP 4-byte token must not be classified: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"PRI "), "HTTP"); +} + +#[test] +fn exact_long_http_tokens_are_classified() { + for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] { + assert!( + is_http_probe(token), + "exact long HTTP token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() { + assert_eq!(detect_client_type(b"CONNECT"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH"), "HTTP"); +} + +#[test] +fn light_fuzz_four_byte_ascii_noise_not_misclassified() { + // Deterministic pseudo-fuzz over 4-byte printable ASCII inputs. 
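+    // Same Numerical Recipes LCG as the previous fuzz test; masking with 0x3F and
+    // adding 32 keeps every byte in the printable 0x20..=0x5F range.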
+ let mut x = 0xA17C_93E5u32; + for _ in 0..2048 { + let mut token = [0u8; 4]; + for byte in &mut token { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset + } + + if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "] + .iter() + .any(|m| token.as_slice() == *m) + { + continue; + } + + assert!( + !is_http_probe(&token), + "pseudo-fuzz noise misclassified as HTTP probe: {:?}", + token + ); + } +} diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs new file mode 100644 index 0000000..ed6d1ab --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -0,0 +1,41 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; +use tokio::sync::Barrier; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let workers = 32usize; + let barrier = std::sync::Arc::new(Barrier::new(workers)); + let mut tasks = Vec::with_capacity(workers); + + for _ in 0..workers { + let barrier = std::sync::Arc::clone(&barrier); + tasks.push(tokio::spawn(async move { + barrier.wait().await; + is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await + })); + } + + for task in tasks { + let _ = task.await.expect("parallel cache task must not panic"); + } + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "parallel cold misses must coalesce into a single interface enumeration" + ); +} diff --git a/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs new file mode 100644 index 0000000..d82cf82 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs @@ -0,0 +1,51 @@ +#![cfg(unix)] + +use super::*; + +#[test] +fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() { + let previous = vec![ + "192.168.100.7" + .parse::<IpAddr>() + .expect("must parse interface ip"), + ]; + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert_eq!( + next, previous, + "empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions" + ); +} + +#[test] +fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() { + let previous = vec![ + "192.168.100.7" + .parse::<IpAddr>() + .expect("must parse interface ip"), + ]; + let refreshed = vec![ + "10.55.0.3" + .parse::<IpAddr>() + .expect("must parse refreshed interface ip"), + ]; + + let next = choose_interface_snapshot(&previous, refreshed.clone()); + + assert_eq!(next, refreshed); +} + +#[test] +fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() { + let previous = Vec::new(); + let refreshed = Vec::new(); + + let next = choose_interface_snapshot(&previous, refreshed); + + assert!( + next.is_empty(), + "empty refresh with no previous snapshot should remain empty" + ); +} diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs 
b/src/proxy/tests/masking_interface_cache_security_tests.rs new file mode 100644 index 0000000..17debb0 --- /dev/null +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -0,0 +1,49 @@ +#![cfg(unix)] + +use super::*; +use std::sync::{Mutex, OnceLock}; + +fn interface_cache_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +#[tokio::test] +async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "interface enumeration must be cached across repeated bad-client checks" + ); +} + +#[tokio::test] +async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; + + assert!( + !is_local, + "different port must not be treated as local listener" + ); + assert_eq!( + local_interface_enumerations_for_tests(), + 0, + "port mismatch should bypass interface enumeration entirely" + ); +} diff --git a/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..efa4529 --- /dev/null +++ b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs @@ -0,0 +1,178 @@ +use super::*; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"] +async fn redteam_offline_target_should_drop_idle_client_early() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(150)).await; + let write_res = client_write.write_all(b"probe-should-be-closed").await; + assert!( + write_res.is_err(), + "offline target path still keeps client writable before 
consume timeout" + ); + + handler.abort(); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"] +async fn redteam_offline_target_should_not_sleep_to_mask_refusal() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"\x16\x03\x01\x00\x05hello", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(10), + "offline target path still applies coarse masking sleep and is fingerprintable" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"] +async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() { + let mut samples = Vec::new(); + + for i in 0..12u16 { + let (client_read, mut client_write) = duplex(1024); + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + samples.push(started.elapsed()); + } + + let min = samples.iter().copied().min().unwrap_or_default(); + let max = samples.iter().copied().max().unwrap_or_default(); + let spread = max.saturating_sub(min); + + assert!( + spread <= Duration::from_millis(5), + "offline refusal timing spread too wide for strict red-team envelope: {:?}", + spread + ); +} + +#[tokio::test] +#[ignore = "manual red-team: host resolver failure should complete without panic"] +async fn redteam_dns_resolution_failure_must_not_panic() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string()); + cfg.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + 
local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(2), handler).await; + assert!( + result.is_ok(), + "dns failure path stalled or panicked instead of terminating" + ); +} diff --git a/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs new file mode 100644 index 0000000..b99b4bc --- /dev/null +++ b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::AsyncWrite; + +struct NeverWritable; + +impl AsyncWrite for NeverWritable { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Pending + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() { + let mut writer = NeverWritable; + let started = Instant::now(); + + maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await; + + assert!( + started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30), + "shape padding blocked past timeout budget" + ); +} + +#[tokio::test] +async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() { + let mut output = Vec::new(); + { + let mut writer = tokio::io::BufWriter::new(&mut output); + maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await; + use tokio::io::AsyncWriteExt; + writer.flush().await.unwrap(); + } + + assert!(output.is_empty()); +} diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs new file mode 100644 index 0000000..9ff51ba --- /dev/null +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -0,0 +1,283 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::{Duration, Instant, timeout}; + +const PROD_CAP_BYTES: usize = 5 * 1024 * 1024; + +struct FinitePatternReader { + remaining: usize, + chunk: usize, + read_calls: Arc<AtomicUsize>, +} + +impl FinitePatternReader { + fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self { + Self { + remaining: total, + chunk, + read_calls, + } + } +} + +impl AsyncRead for FinitePatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<std::io::Result<()>> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + let fill = vec![0x5Au8; take]; + buf.put_slice(&fill); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<std::io::Result<usize>> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { 
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+struct NeverReadyReader;
+
+impl AsyncRead for NeverReadyReader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        _buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Poll::Pending
+    }
+}
+
+struct BudgetProbeReader {
+    remaining: usize,
+    total_read: Arc<AtomicUsize>,
+}
+
+impl BudgetProbeReader {
+    fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
+        Self {
+            remaining: total,
+            total_read,
+        }
+    }
+}
+
+impl AsyncRead for BudgetProbeReader {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        if self.remaining == 0 {
+            return Poll::Ready(Ok(()));
+        }
+
+        let take = self.remaining.min(buf.remaining());
+        if take == 0 {
+            return Poll::Ready(Ok(()));
+        }
+
+        let fill = vec![0xA5u8; take];
+        buf.put_slice(&fill);
+        self.remaining -= take;
+        self.total_read.fetch_add(take, Ordering::Relaxed);
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test]
+async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
+    let read_calls = Arc::new(AtomicUsize::new(0));
+    let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
+    let mut writer = CountingWriter::default();
+
+    let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
+
+    assert_eq!(
+        outcome.total, PROD_CAP_BYTES,
+        "copy path must stop at explicit production cap"
+    );
+    assert_eq!(writer.written, PROD_CAP_BYTES);
+    assert!(
+        !outcome.ended_by_eof,
+        "byte-cap stop must not be misclassified as EOF"
+    );
+}
+
+#[tokio::test]
+async fn negative_consume_with_zero_cap_performs_no_reads() {
+    let read_calls = Arc::new(AtomicUsize::new(0));
+    let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
+
+    consume_client_data_with_timeout_and_cap(reader, 0).await;
+
+    assert_eq!(
+        read_calls.load(Ordering::Relaxed),
+        0,
+        "zero cap must return before reading attacker-controlled bytes"
+    );
+}
+
+#[tokio::test]
+async fn edge_copy_below_cap_reports_eof_without_overread() {
+    let read_calls = Arc::new(AtomicUsize::new(0));
+    let payload = 73 * 1024;
+    let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
+    let mut writer = CountingWriter::default();
+
+    let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
+
+    assert_eq!(outcome.total, payload);
+    assert_eq!(writer.written, payload);
+    assert!(
+        outcome.ended_by_eof,
+        "finite upstream below cap must terminate via EOF path"
+    );
+}
+
+#[tokio::test]
+async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
+    let started = Instant::now();
+
+    consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;
+
+    assert!(
+        started.elapsed() < Duration::from_millis(350),
+        "never-ready reader must be bounded by idle/relay timeout protections"
+    );
+}
+
+#[tokio::test]
+async fn integration_consume_path_honors_production_cap_for_large_payload() {
+    let read_calls = Arc::new(AtomicUsize::new(0));
+    let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);
+
+    let bounded = timeout(
+        Duration::from_millis(350),
+        consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
+    )
+    .await;
+
+    assert!(
+        bounded.is_ok(),
+        "consume path with production cap must finish within bounded time"
+    );
+}
+
+#[tokio::test]
+async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap()
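+// Note for the test below: with a declared byte cap C (here 5), the probe
+// reader's shared counter must never observe more than C bytes consumed,
+// ruling out read-ahead past the cap.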
{ + let byte_cap = 5usize; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read)); + + consume_client_data_with_timeout_and_cap(reader, byte_cap).await; + + assert!( + total_read.load(Ordering::Relaxed) <= byte_cap, + "consume path must not read more than configured byte cap" + ); +} + +#[tokio::test] +async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() { + let mut seed = 0x1234_5678_9ABC_DEF0u64; + + for _case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let cap = ((seed & 0x3ffff) as usize).saturating_add(1); + let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1); + let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1); + + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(payload, chunk, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await; + let expected = payload.min(cap); + + assert_eq!( + outcome.total, expected, + "copy total must match min(payload, cap) under fuzzed inputs" + ); + assert_eq!(writer.written, expected); + if payload <= cap { + assert!(outcome.ended_by_eof); + } else { + assert!(!outcome.ended_by_eof); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() { + let workers = 8usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new( + PROD_CAP_BYTES + (idx + 1) * 4096, + 4096 + (idx * 257), + read_calls, + ); + let mut writer = CountingWriter::default(); + copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await + })); + } + + timeout(Duration::from_secs(3), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "stress copy task must stay within production cap" + ); + assert!( + !outcome.ended_by_eof, + "stress task should end due to cap, not EOF" + ); + } + }) + .await + .expect("stress suite must complete in bounded time"); +} diff --git a/src/proxy/tests/masking_relay_guardrails_security_tests.rs b/src/proxy/tests/masking_relay_guardrails_security_tests.rs new file mode 100644 index 0000000..257c0f8 --- /dev/null +++ b/src/proxy/tests/masking_relay_guardrails_security_tests.rs @@ -0,0 +1,105 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_to_mask_enforces_masking_session_byte_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01]; + let extra = vec![0xAB; 96 * 1024]; + + let (client_reader, mut client_writer) = duplex(128 * 1024); + let (mask_read, _mask_read_peer) = duplex(1024); + let (mut mask_observer, mask_write) = duplex(256 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + + let mut observed = Vec::new(); + timeout( + 
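+        // Generous outer bound; the relay is expected to stop at the 32 KiB
+        // session cap long before this two-second timeout fires.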
+        Duration::from_secs(2),
+        mask_observer.read_to_end(&mut observed),
+    )
+    .await
+    .unwrap()
+    .unwrap();
+
+    // In this deterministic test, relay must stop exactly at the configured cap.
+    assert_eq!(
+        observed.len(),
+        initial.len() + (32 * 1024),
+        "masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
+        observed.len(),
+        initial.len(),
+        32 * 1024
+    );
+}
+
+#[tokio::test]
+async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
+    let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
+
+    let (client_reader, mut client_writer) = duplex(8 * 1024);
+    let (mask_read, _mask_read_peer) = duplex(8 * 1024);
+    let (mut mask_observer, mask_write) = duplex(8 * 1024);
+    let initial_for_task = initial.clone();
+
+    let relay = tokio::spawn(async move {
+        relay_to_mask(
+            client_reader,
+            sink(),
+            mask_read,
+            mask_write,
+            &initial_for_task,
+            false,
+            512,
+            4096,
+            false,
+            0,
+            false,
+            32 * 1024,
+        )
+        .await;
+    });
+
+    client_writer.shutdown().await.unwrap();
+
+    let mut observed = Vec::new();
+    timeout(
+        Duration::from_millis(80),
+        mask_observer.read_to_end(&mut observed),
+    )
+    .await
+    .expect("mask backend write side should be half-closed promptly")
+    .unwrap();
+
+    assert_eq!(&observed[..initial.len()], initial.as_slice());
+
+    timeout(Duration::from_secs(2), relay)
+        .await
+        .unwrap()
+        .unwrap();
+}
diff --git a/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs
new file mode 100644
index 0000000..627c48b
--- /dev/null
+++ b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs
@@ -0,0 +1,100 @@
+use super::*;
+use tokio::io::AsyncReadExt;
+use tokio::time::{Duration, timeout};
+
+async fn collect_padding(
+    total_sent: usize,
+    enabled: bool,
+    floor: usize,
+    cap: usize,
+    above_cap_blur: bool,
+    blur_max: usize,
+    aggressive: bool,
+) -> Vec<u8> {
+    let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
+
+    maybe_write_shape_padding(
+        &mut tx,
+        total_sent,
+        enabled,
+        floor,
+        cap,
+        above_cap_blur,
+        blur_max,
+        aggressive,
+    )
+    .await;
+
+    drop(tx);
+
+    let mut output = Vec::new();
+    timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
+        .await
+        .expect("reading padded output timed out")
+        .expect("failed reading padded output");
+    output
+}
+
+#[tokio::test]
+async fn padding_output_is_not_all_zero() {
+    let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
+
+    assert!(
+        output.len() >= 255,
+        "expected at least 255 padding bytes, got {}",
+        output.len()
+    );
+
+    let nonzero = output.iter().filter(|&&b| b != 0).count();
+    // In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
+    // A weak nonzero check can miss severe entropy collapse.
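+    // Back-of-envelope check, assuming uniform random padding bytes: the
+    // expected zero count in 255 bytes is 255/256 ≈ 1 (sd ≈ 1), so requiring
+    // 240+ nonzero bytes tolerates up to 15 zeros, an event many sigma out,
+    // while still catching genuine entropy collapse.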
+ assert!( + nonzero >= 240, + "RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}", + nonzero, + output.len(), + ); +} + +#[tokio::test] +async fn padding_reaches_first_bucket_boundary() { + let output = collect_padding(1, true, 64, 4096, false, 0, false).await; + assert_eq!(output.len(), 63); +} + +#[tokio::test] +async fn disabled_padding_produces_no_output() { + let output = collect_padding(0, false, 256, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn at_cap_without_blur_produces_no_output() { + let output = collect_padding(4096, true, 64, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() { + let output = collect_padding(4096, true, 64, 4096, true, 128, true).await; + assert!(!output.is_empty()); + assert!(output.len() <= 128, "blur exceeded max: {}", output.len()); +} + +#[tokio::test] +async fn stress_padding_runs_are_not_constant_pattern() { + // Stress and sanity-check: repeated runs should not collapse to identical + // first 16 bytes across all samples. + let mut first_chunks = Vec::new(); + for _ in 0..64 { + let out = collect_padding(1, true, 64, 4096, false, 0, false).await; + first_chunks.push(out[..16].to_vec()); + } + + let first = &first_chunks[0]; + let all_same = first_chunks.iter().all(|chunk| chunk == first); + assert!( + !all_same, + "all stress samples had identical prefix, rng output appears degenerate" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 4519d85..c698b55 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall false, 0, false, + 5 * 1024 * 1024, ) .await; }); @@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { false, 0, false, + 5 * 1024 * 1024, ), ) .await; diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs new file mode 100644 index 0000000..7f6cb29 --- /dev/null +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -0,0 +1,330 @@ +use super::*; +use std::net::SocketAddr; +use std::net::TcpListener as StdTcpListener; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant, timeout}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +async fn self_target_detection_matches_literal_ipv4_listener() { + let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_matches_bracketed_ipv6_listener() { + let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_keeps_same_ip_different_port_forwardable() { + let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { + let 
local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_blocks_loopback_target() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); +} + +#[tokio::test] +async fn self_target_fallback_refuses_recursive_loopback_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(local_addr.ip().to_string()); + config.censorship.mask_port = local_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET /", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "self-target masking must fail closed without connecting to local listener" + ); +} + +#[tokio::test] +async fn same_ip_different_port_still_forwards_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /".to_vec(); + let accept_task = tokio::spawn({ + let expected = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn detect_client_type_http_boundary_get_and_post() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"GET /"), "HTTP"); + + assert_eq!(detect_client_type(b"POST"), "HTTP"); + assert_eq!(detect_client_type(b"POST "), "HTTP"); + assert_eq!(detect_client_type(b"POSTX"), "HTTP"); +} + +#[test] +fn detect_client_type_tls_and_length_boundaries() { + assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner"); + assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner"); + + 
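+    // Taken together with the prefix cases above, classification appears to key
+    // on a short ASCII prefix first ("GET ", "POST", TLS record header), then on
+    // raw length, with 10 bytes as the apparent boundary between "port-scanner"
+    // and "unknown".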
assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() { + let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap(); + let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + let header = build_mask_proxy_header(1, peer, local).unwrap(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); +} + +#[test] +fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() { + let floor = usize::MAX / 2 + 1; + let cap = usize::MAX; + let total = floor + 1; + assert_eq!(next_mask_shape_bucket(total, floor, cap), total); +} + +#[tokio::test] +async fn self_target_reject_path_keeps_timing_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client, server) = duplex(1024); + drop(client); + + let started = Instant::now(); + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250), + "self-target reject path must keep coarse timing budget without stalling" + ); +} + +#[tokio::test] +async fn relay_path_idle_timeout_eviction_remains_effective() { + let (client_read, mut client_write) = duplex(1024); + let (mask_read, mask_write) = duplex(1024); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + client_write.write_all(b"a").await.unwrap(); + tokio::time::sleep(Duration::from_millis(180)).await; + let _ = client_write.write_all(b"b").await; + }); + + let started = Instant::now(); + relay_to_mask( + client_read, + tokio::io::sink(), + mask_read, + mask_write, + b"init", + false, + 0, + 0, + false, + 0, + false, + 5 * 1024 * 1024, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180), + "idle-timeout eviction must occur before late trickle write" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_respects_timing_normalization_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client.shutdown().await.unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); + let 
elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220), + "offline-refusal path must honor normalization budget without unbounded drift" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = false; + + let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(120)).await; + client + .write_all(b"still-open-before-timeout") + .await + .expect("connection should still be open before consume timeout expires"); + + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350), + "offline-refusal path must not retain idle client indefinitely" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index 982fd26..4fa8da7 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -43,6 +43,7 @@ async fn run_relay_case( above_cap_blur, above_cap_blur_max_bytes, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index 3c886ba..9abf3c0 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { false, 0, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs new file mode 100644 index 0000000..fda6de7 --- /dev/null +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -0,0 +1,58 @@ +#![cfg(unix)] + +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer"); + let local_addr: SocketAddr = 
"0.0.0.0:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let held_refresh_guard = refresh_lock.lock().await; + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(held_refresh_guard); + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(2), task) + .await + .expect("task must finish in bounded time") + .expect("task must not panic"); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), + "timing normalization floor must start after pre-outcome self-target checks" + ); +} diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..7c176bc --- /dev/null +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::crypto::AesCtr; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +struct CountedWriter { + write_calls: Arc, + fail_writes: bool, +} + +impl CountedWriter { + fn new(write_calls: Arc, fail_writes: bool) -> Self { + Self { + write_calls, + fail_writes, + } + } +} + +impl AsyncWrite for CountedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + if this.fail_writes { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced write failure", + ))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { + let stats = Stats::new(); + let user = "middle-me-writer-no-rollback-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: payload.clone(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(64), + 0, + &bytes_me2c, + 11, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(_))), + "write failure must propagate as I/O error" + ); + assert!( + write_calls.load(Ordering::Relaxed) > 0, + "writer must be attempted after successful quota reservation" + ); + assert_eq!( + 
+        stats.get_user_quota_used(user),
+        payload.len() as u64,
+        "reserved quota must not roll back on write failure"
+    );
+    assert_eq!(
+        stats.get_quota_write_fail_bytes_total(),
+        payload.len() as u64,
+        "write-fail byte metric must include failed payload size"
+    );
+    assert_eq!(
+        stats.get_quota_write_fail_events_total(),
+        1,
+        "write-fail events metric must increment once"
+    );
+    assert_eq!(
+        stats.get_user_total_octets(user),
+        0,
+        "telemetry octets_to should not advance when write fails"
+    );
+    assert_eq!(
+        bytes_me2c.load(Ordering::Relaxed),
+        0,
+        "ME->C committed byte counter must not advance on write failure"
+    );
+}
+
+#[tokio::test]
+async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
+    let stats = Stats::new();
+    let user = "middle-me-writer-precheck-user";
+    let limit = 8u64;
+    let user_stats = stats.get_or_create_user_stats_handle(user);
+    stats.quota_charge_post_write(user_stats.as_ref(), limit);
+
+    let write_calls = Arc::new(AtomicUsize::new(0));
+    let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false));
+    let mut frame_buf = Vec::new();
+    let bytes_me2c = AtomicU64::new(0);
+
+    let result = process_me_writer_response(
+        MeResponse::Data {
+            flags: 0,
+            data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
+        },
+        &mut writer,
+        ProtoTag::Intermediate,
+        &SecureRandom::new(),
+        &mut frame_buf,
+        &stats,
+        user,
+        Some(user_stats.as_ref()),
+        Some(limit),
+        0,
+        &bytes_me2c,
+        12,
+        true,
+        false,
+    )
+    .await;
+
+    assert!(
+        matches!(result, Err(ProxyError::DataQuotaExceeded { .. })),
+        "pre-write quota rejection must return typed quota error"
+    );
+    assert_eq!(
+        write_calls.load(Ordering::Relaxed),
+        0,
+        "writer must not be polled when pre-write quota reservation fails"
+    );
+    assert_eq!(
+        stats.get_me_d2c_quota_reject_pre_write_total(),
+        1,
+        "pre-write quota reject metric must increment"
+    );
+    assert_eq!(
+        stats.get_user_quota_used(user),
+        limit,
+        "failed pre-write reservation must keep previous quota usage unchanged"
+    );
+    assert_eq!(
+        stats.get_quota_write_fail_bytes_total(),
+        0,
+        "write-fail bytes metric must stay unchanged on pre-write reject"
+    );
+    assert_eq!(
+        stats.get_quota_write_fail_events_total(),
+        0,
+        "write-fail events metric must stay unchanged on pre-write reject"
+    );
+    assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
+}
diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs
deleted file mode 100644
index 2c9f3f6..0000000
--- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs
+++ /dev/null
@@ -1,112 +0,0 @@
-use super::*;
-use crate::stats::Stats;
-use dashmap::DashMap;
-use std::sync::Arc;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use tokio::sync::Barrier;
-use tokio::time::{Duration, timeout};
-
-#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
-async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
-    let _guard = super::quota_user_lock_test_scope();
-    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
-    map.clear();
-
-    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
-    for idx in 0..QUOTA_USER_LOCKS_MAX {
-        retained.push(quota_user_lock(&format!(
-            "middle-blackhat-held-{}-{idx}",
-            std::process::id()
-        )));
-    }
-
-    assert_eq!(
-        map.len(),
-        QUOTA_USER_LOCKS_MAX,
-        "precondition: bounded lock cache must be saturated"
-    );
-
-    let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
-    tx.send(C2MeCommand::Close)
-        .await
-
.expect("queue prefill should succeed"); - - let pressure_seq_before = relay_pressure_event_seq(); - let pressure_errors = Arc::new(AtomicUsize::new(0)); - let mut pressure_workers = Vec::new(); - for _ in 0..16 { - let tx = tx.clone(); - let pressure_errors = Arc::clone(&pressure_errors); - pressure_workers.push(tokio::spawn(async move { - if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { - pressure_errors.fetch_add(1, Ordering::Relaxed); - } - })); - } - - let stats = Arc::new(Stats::new()); - let user = format!("middle-blackhat-quota-race-{}", std::process::id()); - let gate = Arc::new(Barrier::new(16)); - - let mut quota_workers = Vec::new(); - for _ in 0..16u8 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - quota_workers.push(tokio::spawn(async move { - gate.wait().await; - let user_lock = quota_user_lock(&user); - let _quota_guard = user_lock.lock().await; - - if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { - return false; - } - stats.add_user_octets_to(&user, 1); - true - })); - } - - let mut ok_count = 0usize; - let mut denied_count = 0usize; - for worker in quota_workers { - let result = timeout(Duration::from_secs(2), worker) - .await - .expect("quota worker must finish") - .expect("quota worker must not panic"); - if result { - ok_count += 1; - } else { - denied_count += 1; - } - } - - for worker in pressure_workers { - timeout(Duration::from_secs(2), worker) - .await - .expect("pressure worker must finish") - .expect("pressure worker must not panic"); - } - - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "black-hat campaign must not overshoot same-user quota under saturation" - ); - assert!(ok_count <= 1, "at most one quota contender may succeed"); - assert!( - denied_count >= 15, - "all remaining contenders must be quota-denied" - ); - - let pressure_seq_after = relay_pressure_event_seq(); - assert!( - pressure_seq_after > pressure_seq_before, - "queue pressure leg must trigger pressure accounting" - ); - assert!( - pressure_errors.load(Ordering::Relaxed) >= 1, - "at least one pressure worker should fail from persistent backpressure" - ); - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs deleted file mode 100644 index fff26b4..0000000 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ /dev/null @@ -1,708 +0,0 @@ -use super::*; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::stats::Stats; -use crate::stream::{BufferPool, PooledBuffer}; -use std::sync::Arc; -use tokio::io::AsyncReadExt; -use tokio::io::duplex; -use tokio::sync::mpsc; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[tokio::test] -async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; - - write_client_payload( 
- &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("abridged quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); - assert_eq!(&plaintext[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Boundary where abridged switches to extended length encoding. - let payload = vec![0x5Au8; 0x7f * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended abridged payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized extended abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); - assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); - assert_eq!(&plaintext[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &[1, 2, 3], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("4-byte aligned"), - "error should explain alignment contract, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &[9, 8, 7, 6, 5], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned secure payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Secure payload must be 4-byte aligned"), - "error should be explicit for fail-closed triage, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_intermediate_quickack_sets_length_msb() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = b"hello-middle-relay"; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 
RPC_FLAG_QUICKACK, - payload, - &rng, - &mut frame_buf, - ) - .await - .expect("intermediate quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read intermediate frame"); - let plaintext = decryptor.decrypt(&encrypted); - - let mut len_bytes = [0u8; 4]; - len_bytes.copy_from_slice(&plaintext[..4]); - let len_with_flags = u32::from_le_bytes(len_bytes); - assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); - assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); - assert_eq!(&plaintext[4..], payload); -} - -#[tokio::test] -async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. - - write_client_payload( - &mut writer, - ProtoTag::Secure, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - // Secure mode adds 1..=3 bytes of randomized tail padding. - let mut encrypted_header = [0u8; 4]; - read_side - .read_exact(&mut encrypted_header) - .await - .expect("must read secure header"); - let decrypted_header = decryptor.decrypt(&encrypted_header); - let header: [u8; 4] = decrypted_header - .try_into() - .expect("decrypted secure header must be 4 bytes"); - let wire_len_raw = u32::from_le_bytes(header); - - assert_ne!( - wire_len_raw & 0x8000_0000, - 0, - "secure quickack bit must be set" - ); - - let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; - assert!(wire_len >= payload.len()); - let padding_len = wire_len - payload.len(); - assert!( - (1..=3).contains(&padding_len), - "secure writer must add bounded random tail padding, got {padding_len}" - ); - - let mut encrypted_body = vec![0u8; wire_len]; - read_side - .read_exact(&mut encrypted_body) - .await - .expect("must read secure body"); - let decrypted_body = decryptor.decrypt(&encrypted_body); - assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); -} - -#[tokio::test] -#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] -async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Exactly one 4-byte word above the encodable 24-bit abridged length range. 
- let payload = vec![0x00u8; (1 << 24) * 4]; - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect_err("oversized abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Abridged frame too large"), - "error must clearly indicate oversize fail-close path, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_ack_intermediate_is_little_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); -} - -#[tokio::test] -async fn write_client_ack_abridged_is_big_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); -} - -#[tokio::test] -async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { - let (mut read_side, write_side) = duplex(1024 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xABu8; 0x7e * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("boundary payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7e); - assert_eq!(&plain[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x42u8; 0x80 * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7f); - 
assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_intermediate_zero_length_emits_header_only() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0, - &[], - &rng, - &mut frame_buf, - ) - .await - .expect("zero-length intermediate payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); -} - -#[tokio::test] -async fn write_client_payload_intermediate_ignores_unrelated_flags() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [7u8; 12]; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0x4000_0000, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 16]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); - assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x1Du8; 64]; - - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len_raw = u32::from_le_bytes(h); - assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); -} - -#[tokio::test] -async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { - let (mut read_side, write_side) = duplex(256 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x55u8; 100]; - let mut seen = [false; 4]; - - for _ in 0..96 { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - 
read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; - let padding_len = wire_len - payload.len(); - assert!((1..=3).contains(&padding_len)); - seen[padding_len] = true; - - let mut encrypted_body = vec![0u8; wire_len]; - read_side.read_exact(&mut encrypted_body).await.unwrap(); - let _ = decryptor.decrypt(&encrypted_body); - } - - let distinct = (1..=3).filter(|idx| seen[*idx]).count(); - assert!( - distinct >= 2, - "padding generator should not collapse to a single outcome under campaign" - ); -} - -#[tokio::test] -async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { - let (mut read_side, write_side) = duplex(128 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let p1 = vec![1u8; 8]; - let p2 = vec![2u8; 16]; - let p3 = vec![3u8; 20]; - - write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) - .await - .unwrap(); - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - &p2, - &rng, - &mut frame_buf, - ) - .await - .unwrap(); - write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) - .await - .unwrap(); - writer.flush().await.unwrap(); - - // Frame 1: abridged short. - let mut e1 = vec![0u8; 1 + p1.len()]; - read_side.read_exact(&mut e1).await.unwrap(); - let d1 = decryptor.decrypt(&e1); - assert_eq!(d1[0], (p1.len() / 4) as u8); - assert_eq!(&d1[1..], p1.as_slice()); - - // Frame 2: intermediate with quickack. - let mut e2 = vec![0u8; 4 + p2.len()]; - read_side.read_exact(&mut e2).await.unwrap(); - let d2 = decryptor.decrypt(&e2); - let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); - assert_ne!(l2 & 0x8000_0000, 0); - assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); - assert_eq!(&d2[4..], p2.as_slice()); - - // Frame 3: secure with bounded tail. 
-    let mut e3h = [0u8; 4];
-    read_side.read_exact(&mut e3h).await.unwrap();
-    let d3h = decryptor.decrypt(&e3h);
-    let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
-    assert!(l3 >= p3.len());
-    assert!((1..=3).contains(&(l3 - p3.len())));
-    let mut e3b = vec![0u8; l3];
-    read_side.read_exact(&mut e3b).await.unwrap();
-    let d3b = decryptor.decrypt(&e3b);
-    assert_eq!(&d3b[..p3.len()], p3.as_slice());
-}
-
-#[test]
-fn should_yield_sender_boundary_matrix_blackhat() {
-    assert!(!should_yield_c2me_sender(0, false));
-    assert!(!should_yield_c2me_sender(0, true));
-    assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
-    assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
-    assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
-    assert!(should_yield_c2me_sender(
-        C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024),
-        true
-    ));
-}
-
-#[test]
-fn should_yield_sender_light_fuzz_matches_oracle() {
-    let mut s: u64 = 0xD00D_BAAD_F00D_CAFE;
-    for _ in 0..5000 {
-        s ^= s << 7;
-        s ^= s >> 9;
-        s ^= s << 8;
-        let sent = (s as usize) & 0x1fff;
-        let backlog = (s & 1) != 0;
-
-        let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET;
-        assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
-    }
-}
-
-#[test]
-fn quota_would_be_exceeded_exact_remaining_one_byte() {
-    let stats = Stats::new();
-    let user = "quota-edge";
-    let quota = 100u64;
-    stats.add_user_octets_to(user, 99);
-
-    assert!(
-        !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1),
-        "exactly remaining budget should be allowed"
-    );
-    assert!(
-        quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
-        "one byte beyond remaining budget must be rejected"
-    );
-}
-
-#[test]
-fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
-    let stats = Stats::new();
-    let user = "quota-saturating-edge";
-    let quota = u64::MAX - 3;
-    stats.add_user_octets_to(user, u64::MAX - 4);
-
-    assert!(
-        quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
-        "saturating arithmetic edge must stay fail-closed"
-    );
-}
-
-#[test]
-fn quota_exceeded_boundary_is_inclusive() {
-    let stats = Stats::new();
-    let user = "quota-inclusive-boundary";
-    stats.add_user_octets_to(user, 50);
-
-    assert!(quota_exceeded_for_user(&stats, user, Some(50)));
-    assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
-}
-
-#[tokio::test]
-async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
-    let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
-    enqueue_c2me_command(&tx, C2MeCommand::Close)
-        .await
-        .expect("close should enqueue on fast path");
-
-    let recv = timeout(TokioDuration::from_millis(50), rx.recv())
-        .await
-        .expect("must receive close command")
-        .expect("close command should be present");
-    assert!(matches!(recv, C2MeCommand::Close));
-}
-
-#[tokio::test]
-async fn enqueue_c2me_data_full_then_drain_preserves_order() {
-    let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
-    tx.send(C2MeCommand::Data {
-        payload: make_pooled_payload(&[1]),
-        flags: 10,
-    })
-    .await
-    .unwrap();
-
-    let tx2 = tx.clone();
-    let producer = tokio::spawn(async move {
-        enqueue_c2me_command(
-            &tx2,
-            C2MeCommand::Data {
-                payload: make_pooled_payload(&[2, 2]),
-                flags: 20,
-            },
-        )
-        .await
-    });
-
-    tokio::time::sleep(TokioDuration::from_millis(10)).await;
-
-    let first = rx.recv().await.expect("first item should exist");
-    match first {
-        C2MeCommand::Data { payload, flags } => {
-            assert_eq!(payload.as_ref(), &[1]);
-            assert_eq!(flags, 10);
-        }
-        C2MeCommand::Close => panic!("unexpected close as first item"),
-    }
-
-    producer.await.unwrap().expect("producer should complete");
-
-    let second = timeout(TokioDuration::from_millis(100), rx.recv())
-        .await
-        .unwrap()
-        .expect("second item should exist");
-    match second {
-        C2MeCommand::Data { payload, flags } => {
-            assert_eq!(payload.as_ref(), &[2, 2]);
-            assert_eq!(flags, 20);
-        }
-        C2MeCommand::Close => panic!("unexpected close as second item"),
-    }
-}
diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs
index 3e0b30f..fd3243d 100644
--- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs
+++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs
@@ -2,8 +2,8 @@ use super::*;
 use crate::crypto::AesCtr;
 use crate::stats::Stats;
 use crate::stream::{BufferPool, CryptoReader};
+use std::sync::Arc;
 use std::sync::atomic::AtomicU64;
-use std::sync::{Arc, Mutex, OnceLock};
 use tokio::io::AsyncWriteExt;
 use tokio::io::duplex;
 use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
     }
 }
 
-fn idle_pressure_test_lock() -> &'static Mutex<()> {
-    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
-    TEST_LOCK.get_or_init(|| Mutex::new(()))
-}
-
-fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
-    match idle_pressure_test_lock().lock() {
-        Ok(guard) => guard,
-        Err(poisoned) => poisoned.into_inner(),
-    }
-}
-
 #[tokio::test]
 async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
     let (reader, _writer) = duplex(1024);
@@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
 
 #[test]
 fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
-    let _guard = acquire_idle_pressure_test_lock();
+    let _guard = relay_idle_pressure_test_scope();
     clear_relay_idle_pressure_state_for_testing();
 
     let stats = Stats::new();
@@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
 
 #[test]
 fn pressure_does_not_evict_without_new_pressure_signal() {
-    let _guard = acquire_idle_pressure_test_lock();
+    let _guard = relay_idle_pressure_test_scope();
     clear_relay_idle_pressure_state_for_testing();
 
     let stats = Stats::new();
@@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
 
 #[test]
 fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
-    let _guard = acquire_idle_pressure_test_lock();
+    let _guard = relay_idle_pressure_test_scope();
     clear_relay_idle_pressure_state_for_testing();
 
     let stats = Stats::new();
@@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
 
 #[test]
 fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
-    let _guard = acquire_idle_pressure_test_lock();
+    let _guard = relay_idle_pressure_test_scope();
     clear_relay_idle_pressure_state_for_testing();
 
     let stats = Stats::new();
@@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
 
 #[test]
 fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
-    let _guard = acquire_idle_pressure_test_lock();
+    let _guard = relay_idle_pressure_test_scope();
     clear_relay_idle_pressure_state_for_testing();
 
     let stats = Stats::new();
@@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
 
 #[test]
 fn
blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { #[test] fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { #[test] fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated( #[test] fn blackhat_stale_pressure_must_not_survive_candidate_churn() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() { #[test] fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting( #[test] fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); @@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs new file mode 100644 index 0000000..b43825c --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -0,0 +1,62 @@ +use super::*; +use std::panic::{AssertUnwindSafe, catch_unwind}; + +#[test] +fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let mut guard = registry + .lock() + 
.expect("registry lock must be acquired before poison"); + guard.by_conn_id.insert( + 999, + RelayIdleCandidateMeta { + mark_order_seq: 1, + mark_pressure_seq: 0, + }, + ); + guard.ordered.insert((1, 999)); + panic!("intentional poison for idle-registry recovery"); + })); + + // Helper lock must recover from poison, reset stale state, and continue. + assert!(mark_relay_idle_candidate(42)); + assert_eq!(oldest_relay_idle_candidate(), Some(42)); + + let before = relay_pressure_event_seq(); + note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert!( + after > before, + "pressure accounting must still advance after poison" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let _guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + panic!("intentional poison while lock held"); + })); + + clear_relay_idle_pressure_state_for_testing(); + + assert_eq!(oldest_relay_idle_candidate(), None); + assert_eq!(relay_pressure_event_seq(), 0); + + assert!(mark_relay_idle_candidate(7)); + assert_eq!(oldest_relay_idle_candidate(), Some(7)); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs deleted file mode 100644 index d06e103..0000000 --- a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; - -#[test] -fn saturation_uses_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let user = format!("middle-quota-overflow-{}", std::process::id()); - let first = quota_user_lock(&user); - let second = quota_user_lock(&user); - - assert!( - Arc::ptr_eq(&first, &second), - "overflow user must get deterministic same lock while cache is saturated" - ); - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock map" - ); - assert!( - map.get(&user).is_none(), - "overflow user should stay outside bounded lock map under saturation" - ); - - drop(retained); -} - -#[test] -fn overflow_striping_keeps_different_users_distributed() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-dist-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - let a = quota_user_lock("middle-overflow-user-a"); - let b = quota_user_lock("middle-overflow-user-b"); - let c = quota_user_lock("middle-overflow-user-c"); - - let distinct = [ - Arc::as_ptr(&a) as usize, - Arc::as_ptr(&b) as usize, - Arc::as_ptr(&c) as usize, - ] - .iter() - .copied() - 
.collect::<std::collections::HashSet<_>>() - .len(); - - assert!( - distinct >= 2, - "striped overflow lock set should avoid collapsing all users to one lock" - ); - - drop(retained); -} - -#[test] -fn reclaim_path_caches_new_user_after_stale_entries_drop() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let user = format!("middle-quota-reclaim-user-{}", std::process::id()); - let got = quota_user_lock(&user); - assert!(map.get(&user).is_some()); - assert!( - Arc::strong_count(&got) >= 2, - "after reclaim, lock should be held both by caller and map" - ); -} - -#[test] -fn overflow_path_same_user_is_stable_across_parallel_threads() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-quota-thread-held-{}-{idx}", - std::process::id() - ))); - } - - let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); - let mut workers = Vec::new(); - for _ in 0..32 { - let user = user.clone(); - workers.push(std::thread::spawn(move || quota_user_lock(&user))); - } - - let first = workers - .remove(0) - .join() - .expect("thread must return lock handle"); - for worker in workers { - let got = worker.join().expect("thread must return lock handle"); - assert!( - Arc::ptr_eq(&first, &got), - "same overflow user should resolve to one striped lock even under contention" - ); - } - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs deleted file mode 100644 index 3be9524..0000000 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ /dev/null @@ -1,2509 +0,0 @@ -use super::*; -use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::network::probe::NetworkDecision; -use crate::proxy::handshake::HandshakeSuccess; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; -use crate::transport::middle_proxy::MePool; -use bytes::Bytes; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use std::collections::{HashMap, HashSet}; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::Mutex; -use std::thread; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::duplex; -use tokio::sync::Barrier; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer { - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -}
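The deleted overflow-lock tests above encode a bounded per-user lock cache with deterministic striping once the cache saturates: a saturated map never grows, and a given overflow user always resolves to the same striped lock. A hedged sketch of that shape, with illustrative constants and a process-local hash (the crate's quota_user_lock may differ in detail):

    use dashmap::DashMap;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::{Arc, Mutex, OnceLock};

    const MAX_CACHED: usize = 1024; // stand-in for QUOTA_USER_LOCKS_MAX
    const STRIPES: usize = 64;      // fixed overflow stripe count (assumed)

    static LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
    static OVERFLOW: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();

    fn user_lock(user: &str) -> Arc<Mutex<()>> {
        let map = LOCKS.get_or_init(DashMap::new);
        if let Some(existing) = map.get(user) {
            return existing.value().clone(); // cached: same user, same lock
        }
        if map.len() < MAX_CACHED {
            // under the hard bound: cache a fresh per-user lock
            return map
                .entry(user.to_owned())
                .or_insert_with(|| Arc::new(Mutex::new(())))
                .value()
                .clone();
        }
        // saturated: hash to a fixed stripe so the same user always gets the
        // same lock and the bounded map never grows
        let stripes = OVERFLOW
            .get_or_init(|| (0..STRIPES).map(|_| Arc::new(Mutex::new(()))).collect());
        let mut h = DefaultHasher::new();
        user.hash(&mut h);
        stripes[h.finish() as usize % STRIPES].clone()
    }

    fn main() {
        let a = user_lock("same-user");
        let b = user_lock("same-user");
        assert!(Arc::ptr_eq(&a, &b)); // deterministic for a given user
    }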
-#[test] -fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET - 1, - true - )); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET, - false - )); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); -} - -#[tokio::test] -async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_closed_channel_recycles_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); - let (tx, rx) = mpsc::channel::<C2MeCommand>(1); - drop(rx); - - let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await; - - assert!(result.is_err(), "closed queue must fail enqueue"); - drop(result); - assert!( - pool.stats().pooled >= 1, - "payload must return to pool when enqueue fails on closed channel" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let (tx, rx) = mpsc::channel::<C2MeCommand>(1); - - tx.send(C2MeCommand::Data { - payload: make_pooled_payload_from(&pool, &[9]), - flags: 1, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let pool2 = pool.clone(); - let blocked_send = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), - flags: 2, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), blocked_send) - .await - .expect("blocked send task must finish") - .expect("blocked send task must not panic"); - - assert!( - result.is_err(), - "closing receiver while sender is blocked must fail enqueue" - ); - drop(result); - assert!( - pool.stats().pooled >= 2, - "both queued and blocked payloads must return to pool after channel close" - ); -}
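The enqueue tests above describe a two-stage send: a non-blocking try_send fast path, then an awaited send bounded by a timeout, with failed commands dropped so pooled payloads recycle via Drop. A sketch under those assumptions (the 250 ms bound, the Cmd type, and the helper name are illustrative, not the crate's values):

    use tokio::sync::mpsc::{self, error::TrySendError};
    use tokio::time::{timeout, Duration};

    #[derive(Debug)]
    enum Cmd {
        Data(Vec<u8>),
        Close,
    }

    // Fast path first; fall back to an awaited send bounded by a timeout so a
    // stalled receiver cannot park the sender forever. On any failure the
    // command is dropped, which is what lets a pooled payload recycle on Drop.
    async fn enqueue(tx: &mpsc::Sender<Cmd>, cmd: Cmd) -> Result<(), &'static str> {
        let cmd = match tx.try_send(cmd) {
            Ok(()) => return Ok(()),                      // queue had room
            Err(TrySendError::Closed(_)) => return Err("closed"),
            Err(TrySendError::Full(c)) => c,              // retry on slow path
        };
        match timeout(Duration::from_millis(250), tx.send(cmd)).await {
            Ok(Ok(())) => Ok(()),
            Ok(Err(_)) => Err("closed"), // receiver dropped while we waited
            Err(_) => Err("timed out"),  // bounded wait expired
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<Cmd>(1);
        enqueue(&tx, Cmd::Data(vec![1])).await.unwrap(); // fast path
        let slow = tokio::spawn(async move { enqueue(&tx, Cmd::Close).await });
        assert!(matches!(rx.recv().await, Some(Cmd::Data(_)))); // drain unblocks the slow path
        slow.await.unwrap().unwrap();
        assert!(matches!(rx.recv().await, Some(Cmd::Close)));
    }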
-#[tokio::test] -async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { - let (tx, _rx) = mpsc::channel::<C2MeCommand>(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 0, - }) - .await - .unwrap(); - - let started = Instant::now(); - let result = enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 1, - }, - ) - .await; - - assert!( - result.is_err(), - "enqueue must fail when queue stays full beyond bounded timeout" - ); - assert!( - started.elapsed() < TokioDuration::from_millis(400), - "full-queue timeout must resolve promptly" - ); -} - -#[test] -fn desync_dedup_cache_is_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!( - should_emit_full_desync(key, false, now), - "unique keys up to cap must be tracked" - ); - } - - assert!( - should_emit_full_desync(u64::MAX, false, now), - "new key above cap must emit once after bounded eviction for forensic visibility" - ); - - assert!( - !should_emit_full_desync(u64::MAX, false, now), - "already tracked key inside dedup window must stay suppressed" - ); -} - -#[test] -fn quota_user_lock_cache_reuses_entry_for_same_user() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-user-a"); - let b = quota_user_lock("quota-user-a"); - assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); -} - -#[test] -fn quota_user_lock_cache_is_bounded_under_unique_churn() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { - let user = format!("quota-user-{idx}"); - let lock = quota_user_lock(&user); - drop(lock); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay within configured bound" - ); -} - -#[test] -fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - for attempt in 0..8u32 { - map.clear(); - - let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("{prefix}-{idx}"); - retained.push(quota_user_lock(&user)); - } - - if map.len() != QUOTA_USER_LOCKS_MAX { - drop(retained); - continue; - } - - let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should use deterministic striping under saturation" - ); - - drop(retained); - return; - } - - panic!("unable to observe stable saturated lock-cache precondition after bounded retries"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-user-{idx}");
- retained.push(quota_user_lock(&user)); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache must be saturated for overflow-user race test" - ); - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-saturated-lock-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); - let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); - let (r1, r2) = tokio::join!(one, two); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "both racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected even when lock cache is saturated" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "saturated lock cache must not permit double-success quota overshoot" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-stress-holder-{idx}"); - retained.push(quota_user_lock(&user)); - } - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-saturated-race-round-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x71, - 12_000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: saturated cache must still enforce exactly one forwarded byte" - ); - } - - drop(retained); -} - -#[test] -fn adversarial_forensics_trace_id_should_not_alias_conn_id() { - let now = Instant::now(); - let trace_id = 0x1122_3344_5566_7788; - let conn_id = 0x8877_6655_4433_2211; - let state = RelayForensicsState { - trace_id, - conn_id, - user: "trace-user".to_string(), - peer: "198.51.100.17:443".parse().unwrap(), - peer_hash: 0x8877_6655_4433_2211, - started_at: now, - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - }; - - assert_ne!( - state.trace_id, state.conn_id, - "security expectation: trace correlation should be independent of connection identity" - ); - assert_eq!(state.trace_id, trace_id); - assert_eq!(state.conn_id, conn_id); -} - -#[tokio::test] -async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { - let (mut writer_side, reader_side) = duplex(8); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) - .await - .expect("ack write must succeed"); - - let mut observed = [0u8; 4]; - writer_side - .read_exact(&mut observed) - .await - .expect("ack bytes must be readable"); - let mut decryptor = AesCtr::new(&key, iv); - let decrypted = decryptor.decrypt(&observed); - - assert_eq!( - decrypted, - 0x11_22_33_44u32.to_be_bytes(), - "abridged ACK should encode confirm bytes in big-endian order" - ); -} - -#[test] -fn desync_dedup_full_cache_churn_stays_suppressed() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!(should_emit_full_desync(key, false, now)); - } - - for offset in 0..2048u64 { - let emitted = should_emit_full_desync(u64::MAX - offset, false, now); - if offset == 0 { - assert!( - emitted, - "first full-cache newcomer should emit for forensic visibility" - ); - } else { - assert!( - !emitted, - "full-cache newcomer churn inside emit interval must stay suppressed" - ); - } - } -} - -#[test] -fn dedup_hash_is_stable_for_same_input_within_process() { - let sample = ( - "scope_user", - hash_ip("198.51.100.7".parse().unwrap()), - ProtoTag::Secure, - ); - let first = hash_value(&sample); - let second = hash_value(&sample); - assert_eq!( - first, second, - "dedup hash must be stable within a process for cache lookups" - ); -} - -#[test] -fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { - let mut seen = HashSet::new(); - - for octet in 1u16..=2048 { - let third = ((octet / 256) & 0xff) as u8; - let fourth = (octet & 0xff) as u8; - let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); - let key = hash_value(&( - "scope_user", - hash_ip(ip), - ProtoTag::Secure, - DESYNC_ERROR_CLASS, - )); - seen.insert(key); - } - - assert_eq!( - seen.len(), - 2048, - "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" - ); -} - -#[test] -fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { - let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); - let mut seen = HashSet::new(); - let samples = 8192usize; - - for _ in 0..samples { - let user_seed: u64 = rng.random(); - let peer_seed: u64 = rng.random(); - let 
proto = if (peer_seed & 1) == 0 { - ProtoTag::Secure - } else { - ProtoTag::Intermediate - }; - let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); - seen.insert(key); - } - - let collisions = samples - seen.len(); - assert!( - collisions <= 1, - "light fuzz collision count should remain negligible for 64-bit dedup keys" - ); -} - -#[test] -fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; - - let mut emitted_count = 0usize; - for key in 0..total as u64 { - let emitted = should_emit_full_desync(key, false, now); - if emitted { - emitted_count += 1; - } - } - - assert_eq!( - emitted_count, - DESYNC_DEDUP_MAX_ENTRIES + 1, - "after capacity is reached, same-tick newcomer churn must be rate-limited" - ); - - let len = DESYNC_DEDUP - .get() - .expect("dedup cache must be initialized by stress run") - .len(); - assert!( - len <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must stay bounded under stress churn" - ); -} - -#[test] -fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Same-tick newcomer storm: only the first should emit full forensic record. - let mut burst_emits = 0usize; - for i in 0..1024u64 { - if should_emit_full_desync(10_000_000 + i, false, base_now) { - burst_emits += 1; - } - } - assert_eq!( - burst_emits, 1, - "full-cache newcomer burst must be bounded to a single full emit per interval" - ); - - // After each interval elapses, one newcomer may emit again. 
- for step in 1..=6u64 { - let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; - assert!( - should_emit_full_desync(20_000_000 + step, false, t), - "full-cache newcomer should re-emit once interval has elapsed" - ); - assert!( - !should_emit_full_desync(30_000_000 + step, false, t), - "additional newcomers in the same interval tick must remain suppressed" - ); - } -} - -#[test] -fn full_cache_mode_override_emits_every_event() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for i in 0..10_000u64 { - assert!( - should_emit_full_desync(100_000_000 + i, true, now), - "desync_all_full override must bypass dedup and rate-limit suppression" - ); - } -} - -#[test] -fn report_desync_stats_follow_rate_limited_full_cache_policy() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let stats = Stats::new(); - let mut state = make_forensics_state(); - state.started_at = base_now; - - for i in 0..128u64 { - state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 3, - 1024, - 4096, - Some([0x16, 0x03, 0x03, 0x00]), - &stats, - ); - } - - assert_eq!( - stats.get_desync_total(), - 128, - "every detected desync must increment total counter" - ); - assert_eq!( - stats.get_desync_full_logged(), - 1, - "same-interval full-cache newcomer storm must allow only one full forensic emit" - ); - assert_eq!( - stats.get_desync_suppressed(), - 127, - "remaining same-interval full-cache newcomer events must be suppressed" - ); - - // After one full interval in real wall clock, a newcomer should emit again. 
- thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); - state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 4, - 1024, - 4097, - Some([0x16, 0x03, 0x03, 0x01]), - &stats, - ); - - assert_eq!( - stats.get_desync_full_logged(), - 2, - "full forensic emission must recover after rate-limit interval" - ); -} - -#[test] -fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let emits = Arc::new(AtomicUsize::new(0)); - let mut workers = Vec::new(); - for worker_id in 0..32u64 { - let emits = Arc::clone(&emits); - workers.push(thread::spawn(move || { - for i in 0..512u64 { - let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; - if should_emit_full_desync(key, false, base_now) { - emits.fetch_add(1, Ordering::Relaxed); - } - } - })); - } - - for worker in workers { - worker.join().expect("worker thread must not panic"); - } - - assert_eq!( - emits.load(Ordering::Relaxed), - 1, - "concurrent same-interval full-cache storm must allow only one full forensic emit" - ); -} - -#[test] -fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); - let mut model_last_emit: Option<Instant> = None; - - for i in 0..4096u64 { - let jitter_ms: u64 = rng.random_range(0..=3000); - let t = base_now + TokioDuration::from_millis(jitter_ms); - let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::<u64>(); - let actual = should_emit_full_desync(key, false, t); - - let expected = match model_last_emit { - None => { - model_last_emit = Some(t); - true - } - Some(last) => { - match t.checked_duration_since(last) { - Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { - model_last_emit = Some(t); - true - } - Some(_) => false, - None => { - // Match production fail-open behavior for non-monotonic synthetic input. - model_last_emit = Some(t); - true - } - } - } - }; - - assert_eq!( - actual, expected, - "full-cache rate-limit gate diverged from reference model under light fuzz" - ); - } -} - -#[test] -fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Poison the full-cache gate lock intentionally. 
- let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); - let _ = std::panic::catch_unwind(|| { - let _lock = gate - .lock() - .expect("gate lock must be lockable before poison"); - panic!("intentional gate poison for fail-closed regression"); - }); - - let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); - assert!( - !emitted, - "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain bounded even when gate lock is poisoned" - ); -} - -#[test] -fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // First event seeds the gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0001, - false, - base_now + TokioDuration::from_millis(900) - )); - - // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0002, - false, - base_now + TokioDuration::from_millis(100) - )); - - // Same instant again remains suppressed after reset. - assert!(!should_emit_full_desync( - 0xABCD_0000_0000_0003, - false, - base_now + TokioDuration::from_millis(100) - )); -} - -#[test] -fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - // Fill with fresh entries so stale-pruning does not apply. - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let before_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect(); - - let newcomer_key = u64::MAX; - let emitted = should_emit_full_desync(newcomer_key, false, base_now); - assert!( - emitted, - "new entry under full fresh cache must emit after bounded eviction" - ); - assert!( - dedup.get(&newcomer_key).is_some(), - "new key must be inserted after bounded eviction" - ); - - let after_keys: std::collections::HashSet<u64> = dedup.iter().map(|e| *e.key()).collect(); - let removed_count = before_keys.difference(&after_keys).count(); - let added_count = after_keys.difference(&before_keys).count(); - - assert_eq!( - removed_count, 1, - "full-cache insertion must evict exactly one prior key" - ); - assert_eq!( - added_count, 1, - "full-cache insertion must add exactly one newcomer key" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain hard-bounded after full-cache churn" - ); -}
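Taken together, the dedup and full-cache tests above specify one gate: per-key suppression inside a window, a hard entry bound with single-key eviction, at most one newcomer emit per interval when the cache is full, fail-closed on lock poison, and fail-open on non-monotonic time. A compact model that satisfies those properties (constants and names are stand-ins for DESYNC_DEDUP_MAX_ENTRIES, DESYNC_DEDUP_WINDOW, and DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL, not the crate's values):

    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::time::{Duration, Instant};

    const MAX_ENTRIES: usize = 4096;
    const WINDOW: Duration = Duration::from_secs(60);
    const MIN_INTERVAL: Duration = Duration::from_secs(1);

    struct Gate {
        seen: HashMap<u64, Instant>,     // bounded dedup cache
        last_full_emit: Option<Instant>, // rate-limit gate for full-cache newcomers
    }

    fn should_emit(gate: &Mutex<Gate>, key: u64, now: Instant) -> bool {
        let mut g = match gate.lock() {
            Ok(guard) => guard,
            Err(_) => return false, // poisoned gate: fail closed, never panic
        };
        let prev = g.seen.get(&key).copied();
        if let Some(prev) = prev {
            if now.saturating_duration_since(prev) < WINDOW {
                return false; // tracked key inside the dedup window stays suppressed
            }
            g.seen.insert(key, now); // window elapsed: refresh and re-emit
            return true;
        }
        if g.seen.len() >= MAX_ENTRIES {
            let allow = match g.last_full_emit {
                None => true,
                Some(last) => match now.checked_duration_since(last) {
                    Some(elapsed) => elapsed >= MIN_INTERVAL, // one newcomer emit per interval
                    None => true, // non-monotonic input: fail open and reset the gate
                },
            };
            if allow {
                g.last_full_emit = Some(now);
                let victim = g.seen.keys().next().copied();
                if let Some(v) = victim {
                    g.seen.remove(&v); // evict exactly one key for the newcomer
                }
                g.seen.insert(key, now);
            }
            return allow;
        }
        g.seen.insert(key, now);
        true // first sighting of a new key always emits
    }

    fn main() {
        let gate = Mutex::new(Gate { seen: HashMap::new(), last_full_emit: None });
        let t0 = Instant::now();
        assert!(should_emit(&gate, 1, t0));  // first sighting emits
        assert!(!should_emit(&gate, 1, t0)); // same key inside window: suppressed
    }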
-#[test] -fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let key = 0xC0DE_CAFE_u64; - let start = Instant::now(); - - assert!( - should_emit_full_desync(key, false, start), - "first event for key must emit full forensic record" - ); - - // Deterministic pseudo-random time deltas around dedup window edge. - let mut s: u64 = 0x1234_5678_9ABC_DEF0; - for _ in 0..2048 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - - let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); - let now = start + TokioDuration::from_millis(delta_ms); - let emitted = should_emit_full_desync(key, false, now); - - if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { - assert!( - !emitted, - "events inside dedup window must remain suppressed" - ); - } else { - // Once window elapsed for this key, at least one sample should re-emit and refresh. - if emitted { - return; - } - } - } - - panic!("expected at least one post-window sample to re-emit forensic record"); -} - -fn make_forensics_state() -> RelayForensicsState { - RelayForensicsState { - trace_id: 1, - conn_id: 2, - user: "test-user".to_string(), - peer: "127.0.0.1:50000".parse::<SocketAddr>().unwrap(), - peer_hash: 3, - started_at: Instant::now(), - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - } -} - -fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> -where - R: tokio::io::AsyncRead + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoReader::new(reader, AesCtr::new(&key, iv)) -} - -fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W> -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -async fn make_me_pool_for_abort_test(stats: Arc<Stats>) -> Arc<MePool> { - let general = GeneralConfig::default(); - - MePool::new( - None, - vec![1u8; 32], - None, - false, - None, - Vec::new(), - 1, - None, - 12, - 1200, - HashMap::new(), - HashMap::new(), - None, - NetworkDecision::default(), - None, - Arc::new(SecureRandom::new()), - stats, - general.me_keepalive_enabled, - general.me_keepalive_interval_secs, - general.me_keepalive_jitter_secs, - general.me_keepalive_payload_random, - general.rpc_proxy_req_every, - general.me_warmup_stagger_enabled, - general.me_warmup_step_delay_ms, - general.me_warmup_step_jitter_ms, - general.me_reconnect_max_concurrent_per_dc, - general.me_reconnect_backoff_base_ms, - general.me_reconnect_backoff_cap_ms, - general.me_reconnect_fast_retry_count, - general.me_single_endpoint_shadow_writers, - general.me_single_endpoint_outage_mode_enabled, - general.me_single_endpoint_outage_disable_quarantine, - general.me_single_endpoint_outage_backoff_min_ms, - general.me_single_endpoint_outage_backoff_max_ms, - general.me_single_endpoint_shadow_rotate_every_secs, - general.me_floor_mode, - general.me_adaptive_floor_idle_secs, - general.me_adaptive_floor_min_writers_single_endpoint, - general.me_adaptive_floor_min_writers_multi_endpoint, - general.me_adaptive_floor_recover_grace_secs, - general.me_adaptive_floor_writers_per_core_total, - general.me_adaptive_floor_cpu_cores_override, - general.me_adaptive_floor_max_extra_writers_single_per_core, - general.me_adaptive_floor_max_extra_writers_multi_per_core, - general.me_adaptive_floor_max_active_writers_per_core, - general.me_adaptive_floor_max_warm_writers_per_core, - general.me_adaptive_floor_max_active_writers_global, - general.me_adaptive_floor_max_warm_writers_global, - general.hardswap, - general.me_pool_drain_ttl_secs, - general.me_instadrain, - general.me_pool_drain_threshold, - general.me_pool_drain_soft_evict_enabled, - general.me_pool_drain_soft_evict_grace_secs, - general.me_pool_drain_soft_evict_per_writer, - general.me_pool_drain_soft_evict_budget_per_core, - general.me_pool_drain_soft_evict_cooldown_ms, - general.effective_me_pool_force_close_secs(), - general.me_pool_min_fresh_ratio, - general.me_hardswap_warmup_delay_min_ms, - general.me_hardswap_warmup_delay_max_ms, - general.me_hardswap_warmup_extra_passes, - general.me_hardswap_warmup_pass_backoff_base_ms, - general.me_bind_stale_mode, - general.me_bind_stale_ttl_secs, - general.me_secret_atomic_snapshot, - general.me_deterministic_writer_sort, - MeWriterPickMode::default(), - general.me_writer_pick_sample_size, - MeSocksKdfPolicy::default(), - general.me_writer_cmd_channel_capacity, - general.me_route_channel_capacity, - general.me_route_backpressure_base_timeout_ms, - general.me_route_backpressure_high_timeout_ms, - general.me_route_backpressure_high_watermark_pct, - general.me_reader_route_data_wait_ms, - general.me_health_interval_ms_unhealthy, - general.me_health_interval_ms_healthy, - general.me_warn_rate_limit_ms, - MeRouteNoWriterMode::default(), - general.me_route_no_writer_wait_ms, - general.me_route_inline_recovery_attempts, - general.me_route_inline_recovery_wait_ms, - ) -}
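The make_crypto_reader/make_crypto_writer helpers above can decrypt with a freshly constructed cipher because CTR mode is a pure keystream XOR: encryption and decryption are the same operation given the same key/IV. A round-trip demonstration using the RustCrypto aes and ctr crates (the crate's own AesCtr wrapper, which takes a u128 IV, is assumed to behave equivalently):

    use aes::Aes256;
    use ctr::cipher::{KeyIvInit, StreamCipher};
    use ctr::Ctr128BE;

    type Aes256Ctr = Ctr128BE<Aes256>;

    fn main() {
        let key = [0u8; 32];
        let iv = [0u8; 16];

        let mut buf = *b"frame bytes";
        Aes256Ctr::new(&key.into(), &iv.into()).apply_keystream(&mut buf); // encrypt
        Aes256Ctr::new(&key.into(), &iv.into()).apply_keystream(&mut buf); // decrypt

        assert_eq!(&buf, b"frame bytes"); // independent instances round-trip
    }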
-fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> { - let key = [0u8; 32]; - let iv = 0u128; - let mut cipher = AesCtr::new(&key, iv); - cipher.encrypt(plaintext) -} - -#[tokio::test] -async fn read_client_payload_times_out_on_header_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, _writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled header read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_times_out_on_payload_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, mut writer) = duplex(1024); - let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); - writer.write_all(&encrypted_len).await.unwrap(); - - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled payload body read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_large_intermediate_frame_is_exact() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(262_144); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as 
u8).wrapping_mul(31))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!( - frame.len(), - payload_len, - "payload size must match wire length" - ); - for (idx, byte) in frame.iter().enumerate() { - assert_eq!(*byte, (idx as u8).wrapping_mul(31)); - } - assert_eq!(frame_counter, 1, "exactly one frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_strips_tail_padding_bytes() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; - let tail = [0xeeu8, 0xff, 0x99]; - let wire_len = payload.len() + tail.len(); - - let mut plaintext = Vec::with_capacity(4 + wire_len); - plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - plaintext.extend_from_slice(&tail); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("secure payload read must succeed") - .expect("secure frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "one secure frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_rejects_wire_len_below_4() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let mut plaintext = Vec::with_capacity(7); - plaintext.extend_from_slice(&3u32.to_le_bytes()); - plaintext.extend_from_slice(&[1u8, 2, 3]); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), - "secure wire length below 4 must be fail-closed by the frame-too-small guard" - ); -} - -#[tokio::test] -async fn read_client_payload_intermediate_skips_zero_len_frame() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); 
- let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; - let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); - plaintext.extend_from_slice(&0u32.to_le_bytes()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("intermediate payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); -} - -#[tokio::test] -async fn read_client_payload_abridged_extended_len_sets_quickack() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = 4 * 130; - let len_words = (payload_len / 4) as u32; - let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); - plaintext.push(0xff | 0x80); - let lw = len_words.to_le_bytes(); - plaintext.extend_from_slice(&lw[..3]); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("abridged payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!( - quickack, - "quickack bit must be propagated from abridged header" - ); - assert_eq!(frame.len(), payload_len); - assert_eq!(frame_counter, 1, "one abridged frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_returns_buffer_to_pool_after_emit() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 8)); - pool.preallocate(1); - assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - // Force growth beyond default pool buffer size to catch ownership-take regressions. 
- let payload_len = 257usize; - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let _ = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 8, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert_eq!(frame_counter, 1); - let pool_stats = pool.stats(); - assert!( - pool_stats.pooled >= 1, - "emitted payload buffer must be returned to pool to avoid pool drain" - ); -} - -#[tokio::test] -async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 2)); - pool.preallocate(1); - assert_eq!( - pool.stats().pooled, - 1, - "one pooled buffer must be available" - ); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; - let mut plaintext = Vec::with_capacity(4 + payload.len()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let (frame, quickack) = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert!(!quickack); - assert_eq!(frame.as_ref(), &payload); - assert_eq!( - pool.stats().pooled, - 0, - "buffer must stay checked out while frame payload is alive" - ); - - drop(frame); - assert!( - pool.stats().pooled >= 1, - "buffer must return to pool only after frame drop" - ); -} - -#[tokio::test] -async fn enqueue_c2me_close_unblocks_after_queue_drain() { - let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x41]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("first queued item must be present"); - assert!(matches!(first, C2MeCommand::Data { .. })); - - close_task - .await - .unwrap() - .expect("close enqueue must succeed after drain"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("close command must follow after queue drain"); - assert!(matches!(second, C2MeCommand::Close)); -}
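The pool tests above rely on an RAII contract: a pooled buffer counts as checked out while the frame holding it is alive and returns to the pool only on Drop. A std-only sketch of that contract with illustrative Pool/Pooled types (the crate's BufferPool/PooledBuffer are assumed to follow the same shape):

    use std::sync::{Arc, Mutex};

    struct Pool {
        free: Mutex<Vec<Vec<u8>>>, // recycled backing buffers
    }

    struct Pooled {
        buf: Option<Vec<u8>>,
        pool: Arc<Pool>,
    }

    impl Pool {
        fn get(pool: &Arc<Pool>) -> Pooled {
            let buf = pool.free.lock().unwrap().pop().unwrap_or_default();
            Pooled { buf: Some(buf), pool: Arc::clone(pool) }
        }
        fn pooled(&self) -> usize {
            self.free.lock().unwrap().len()
        }
    }

    impl Drop for Pooled {
        fn drop(&mut self) {
            if let Some(mut buf) = self.buf.take() {
                buf.clear(); // drop contents, keep capacity
                self.pool.free.lock().unwrap().push(buf); // return on drop
            }
        }
    }

    fn main() {
        let pool = Arc::new(Pool { free: Mutex::new(vec![Vec::with_capacity(64)]) });
        let frame = Pool::get(&pool);
        assert_eq!(pool.pooled(), 0); // checked out while the frame is alive
        drop(frame);
        assert_eq!(pool.pooled(), 1); // returned only after the frame drops
    }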
-#[tokio::test] -async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { - let (tx, rx) = mpsc::channel::<C2MeCommand>(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x42]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), close_task) - .await - .expect("close task must finish") - .expect("close task must not panic"); - assert!( - result.is_err(), - "close enqueue must fail cleanly when receiver is dropped under pressure" - ); -} - -#[tokio::test] -async fn process_me_writer_response_ack_obeys_flush_policy() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let immediate = process_me_writer_response( - MeResponse::Ack(0x11223344), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - &bytes_me2c, - 77, - true, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - immediate, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: true, - } - )); - - let delayed = process_me_writer_response( - MeResponse::Ack(0x55667788), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - &bytes_me2c, - 77, - false, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - delayed, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: false, - } - )); -} - -#[tokio::test] -async fn process_me_writer_response_data_updates_byte_accounting() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; - let outcome = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload.clone()), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - &bytes_me2c, - 88, - false, - false, - ) - .await - .expect("data response must be processed"); - - assert!(matches!( - outcome, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes, - flush_immediately: false, - } if bytes == payload.len() - )); - assert_eq!( - bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), - payload.len() as u64, - "ME->C byte accounting must increase by emitted payload size" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_enforces_live_user_quota() { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from("quota-user", 10); - - let result = 
process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "quota-user", - Some(12), - &bytes_me2c, - 89, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), - "ME->client runtime path must terminate when live user quota is crossed" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "quota exhaustion must not write any ciphertext to the client stream" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "quota-race-user"; - - let (writer_side_a, _reader_side_a) = duplex(1024); - let (writer_side_b, _reader_side_b) = duplex(1024); - let mut writer_a = make_crypto_writer(writer_side_a); - let mut writer_b = make_crypto_writer(writer_side_b); - let mut frame_buf_a = Vec::new(); - let mut frame_buf_b = Vec::new(); - let rng_a = SecureRandom::new(); - let rng_b = SecureRandom::new(); - - let fut_a = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11]), - }, - &mut writer_a, - ProtoTag::Intermediate, - &rng_a, - &mut frame_buf_a, - &stats, - user, - Some(1), - &bytes_me2c, - 91, - false, - false, - ); - let fut_b = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x22]), - }, - &mut writer_b, - ProtoTag::Intermediate, - &rng_b, - &mut frame_buf_b, - &stats, - user, - Some(1), - &bytes_me2c, - 92, - false, - false, - ); - - let (result_a, result_b) = tokio::join!(fut_a, fut_b); - - assert!( - matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_a, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_b, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - stats.get_user_total_octets(user) <= 1, - "same-user concurrent middle-relay responses must not overshoot the configured quota" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() - { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_to("partial-quota-user", 3); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "partial-quota-user", - Some(4), - &bytes_me2c, - 90, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), - "ME->client runtime path must reject oversized payloads before writing" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "oversized payloads must not leak any partial ciphertext to the client stream" - 
); -} - -#[tokio::test] -async fn middle_relay_abort_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "abort-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50001".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xdecafbad, - )); - - let started = tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await; - assert!( - started.is_ok(), - "middle relay must increment route gauge before abort" - ); - - relay_task.abort(); - let joined = relay_task.await; - assert!( - joined.is_err(), - "aborted middle relay task must return join error" - ); - - tokio::time::sleep(TokioDuration::from_millis(20)).await; - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay task is aborted mid-flight" - ); - - drop(client_side); -} - -#[tokio::test] -async fn middle_relay_cutover_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "cutover-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50003".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xfeed_beef, - )); - - tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("middle relay must increment route gauge before cutover"); - - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance route 
generation" - ); - - let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) - .await - .expect("middle relay must terminate after cutover") - .expect("middle relay task must not panic"); - assert!( - relay_result.is_err(), - "cutover should terminate middle relay session" - ); - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "client-visible cutover error must stay generic and avoid route-internal metadata" - ); - - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay exits on cutover" - ); - - drop(client_side); -} - -async fn run_quota_race_attempt( - stats: &Stats, - bytes_me2c: &AtomicU64, - user: &str, - payload: u8, - conn_id: u64, - barrier: Arc, -) -> Result { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - barrier.wait().await; - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![payload]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats, - user, - Some(1), - bytes_me2c, - conn_id, - false, - false, - ) - .await -} - -#[tokio::test] -async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(256); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - 4096, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - result.is_err(), - "oversized abridged length must fail closed" - ); - assert_eq!( - frame_counter, 0, - "oversized frame must not be counted as accepted" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); - let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); - - let (r1, r2) = tokio::join!(f1, f2); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "first racer must either finish or fail closed on quota" - ); - assert!( - matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "second racer must either finish or fail closed on quota" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), - "at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "same-user race must forward/account exactly one payload byte" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_bursts_never_allow_double_success_per_round() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-race-burst-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x33, - 6000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" - ); - } -} - -#[tokio::test] -async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { - let session_count = 6usize; - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let mut relay_tasks = Vec::with_capacity(session_count); - let mut client_sides = Vec::with_capacity(session_count); - - for idx in 0..session_count { - let (server_side, client_side) = duplex(64 * 1024); - client_sides.push(client_side); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: format!("cutover-storm-middle-user-{idx}"), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), - 52000 + idx as u16, - ), - is_tls: false, - }; - - relay_tasks.push(tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_snapshot, - 0xB000_0000 + idx as u64, - ))); - } - - tokio::time::timeout(TokioDuration::from_secs(4), async { - loop { - if stats.get_current_connections_me() == session_count as u64 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("all middle sessions must become active before cutover storm"); - - let route_runtime_flipper = route_runtime.clone(); - let flipper = tokio::spawn(async move { - for step in 0..64u32 { - let mode = if (step & 1) == 0 { - RelayRouteMode::Direct - } else { - RelayRouteMode::Middle - }; - let _ = route_runtime_flipper.set_mode(mode); - 
tokio::time::sleep(TokioDuration::from_millis(15)).await; - } - }); - - for relay_task in relay_tasks { - let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) - .await - .expect("middle relay task must finish under cutover storm") - .expect("middle relay task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "storm-cutover termination must remain generic for all middle sessions" - ); - } - - flipper.abort(); - let _ = flipper.await; - - assert_eq!( - stats.get_current_connections_me(), - 0, - "middle route gauge must return to zero after cutover storm" - ); - - drop(client_sides); -} - -#[tokio::test] -async fn secure_padding_distribution_in_relay_writer() { - timeout(TokioDuration::from_secs(10), async { - let (mut client_side, relay_side) = duplex(512 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = Arc::new(SecureRandom::new()); - let mut frame_buf = Vec::new(); - let mut decryptor = AesCtr::new(&key, iv); - - let mut padding_counts = [0usize; 4]; - let iterations = 180usize; - let payload = vec![0xAAu8; 100]; // 4-byte aligned - - for _ in 0..iterations { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload write must succeed"); - writer - .flush() - .await - .expect("writer flush must complete so encrypted frame becomes readable"); - - let mut len_buf = [0u8; 4]; - client_side - .read_exact(&mut len_buf) - .await - .expect("must read encrypted secure length"); - let decrypted_len_bytes = decryptor.decrypt(&len_buf); - let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes - .try_into() - .expect("decrypted length must be 4 bytes"); - let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; - - assert!( - wire_len >= payload.len(), - "wire length must include at least payload bytes" - ); - let padding_len = wire_len - payload.len(); - assert!(padding_len >= 1 && padding_len <= 3); - padding_counts[padding_len] += 1; - - // Drain and decrypt frame bytes so CTR state stays aligned across writes. - let mut trash = vec![0u8; wire_len]; - client_side - .read_exact(&mut trash) - .await - .expect("must read encrypted secure frame body"); - let _ = decryptor.decrypt(&trash); - } - - for p in 1..=3 { - let count = padding_counts[p]; - assert!( - count > iterations / 8, - "padding length {p} is under-represented ({count}/{iterations})" - ); - } - }) - .await - .expect("secure padding distribution test exceeded runtime budget"); -} - -#[tokio::test] -async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - // Create an ME pool. 
- let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // ConnRegistry ids are monotonic; reserve one id so we can predict the - // next session conn_id and close it deterministically without relying on - // writer-bound views such as active_conn_ids(). - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_runtime.snapshot(), - 0x1234_5678, - )); - - // Wait until session startup is visible, then unregister the predicted - // conn_id to close the per-session ME response channel. - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before channel close simulation"); - - me_pool.registry().unregister(target_conn_id).await; - - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("Session task must terminate after ME drop and client EOF") - .expect("Session task must not panic"); - - assert!( - result.is_ok(), - "Session should complete cleanly after ME drop when client closes, got: {:?}", - result - ); -} - -#[tokio::test] -async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { - let (client_reader_side, _client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // Predict the next conn_id so we can force-drop its ME channel deterministically. 
- let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user-cutover".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let runtime_clone = route_runtime.clone(); - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - runtime_clone.subscribe(), - runtime_clone.snapshot(), - 0xC001_CAFE, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before race trigger"); - - // Race ME channel drop with route cutover and assert generic client-visible outcome. - me_pool.registry().unregister(target_conn_id).await; - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance generation" - ); - - let relay_result = timeout(TokioDuration::from_secs(6), session_task) - .await - .expect("session must terminate under ME-drop + cutover race") - .expect("session task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "race outcome must remain generic and not leak ME internals, got: {:?}", - relay_result - ); -} - -#[tokio::test] -async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - for round in 0..32u64 { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: format!("stress-me-drop-eof-{round}"), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_runtime.snapshot(), - 0xD00D_0000 + round, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("session must start before forced drop in burst round"); - - 
me_pool.registry().unregister(target_conn_id).await;
-        drop(client_writer_side);
-
-        let result = timeout(TokioDuration::from_secs(2), session_task)
-            .await
-            .expect("burst round session must terminate quickly")
-            .expect("burst round session must not panic");
-
-        assert!(
-            result.is_ok(),
-            "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}",
-            result
-        );
-    }
-}
diff --git a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
index 2635a28..fbb9081 100644
--- a/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
+++ b/src/proxy/tests/middle_relay_stub_completion_security_tests.rs
@@ -126,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
                 payload: make_pooled_payload(&[0xBB, 0xCC]),
                 flags: 2,
             },
+            None,
         )
         .await
     });
diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs
new file mode 100644
index 0000000..6b1d511
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs
@@ -0,0 +1,372 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+use tokio::task::JoinSet;
+use tokio::time::{Duration as TokioDuration, sleep};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB200_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+async fn read_once<T: AsyncRead + Unpin + Send>(
+    crypto_reader: &mut CryptoReader<T>,
+    proto: ProtoTag,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    idle_state: &mut RelayClientIdleState,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    let buffer_pool = Arc::new(BufferPool::new());
+    let stats = Stats::new();
+    let idle_policy = make_enabled_idle_policy();
+    let last_downstream_activity_ms = AtomicU64::new(0);
+    read_client_payload_with_idle_policy(
+        crypto_reader,
+        proto,
+        1024,
+        &buffer_pool,
+        forensics,
+        frame_counter,
+        &stats,
+        &idle_policy,
+        idle_state,
+        &last_downstream_activity_ms,
+        forensics.started_at,
+    )
+    .await
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
+    let mut set = JoinSet::new();
+
+    for idx in 0..32u64 {
+        set.spawn(async move {
+            let (reader, mut writer) = duplex(4096);
+            let mut crypto_reader =
make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(1000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = run_relay_test_step_timeout( + "tiny flood task", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel tiny flood worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { + let mut set = JoinSet::new(); + + for idx in 0..24u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(2048); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(2000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [idx as u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(20); + for _ in 0..6 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = run_relay_test_step_timeout( + "benign tiny burst read", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("benign payload must parse") + .expect("benign payload must return frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel benign worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { + let mut set = JoinSet::new(); + + for idx in 0..12u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(3000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(2000); + for n in 0..180u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + for chunk in encrypted.chunks(17) { + writer.write_all(chunk).await.unwrap(); + sleep(TokioDuration::from_millis(1)).await; + } + drop(writer); + }); + + let mut closed = false; + for _ in 0..220 { + let result = run_relay_test_step_timeout( + "alternating jitter read step", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match result { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error in alternating jitter case: {other}"), + } + } + + writer_task + .await + .expect("writer jitter task must 
not panic"); + assert!(closed, "alternating attack must close before EOF"); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("alternating jitter worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_mixed_population_attackers_close_benign_survive() { + let mut set = JoinSet::new(); + + for idx in 0..20u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(4000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + if idx % 2 == 0 { + let mut plaintext = Vec::with_capacity(1280); + for n in 0..140u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n, n, n]); + } + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..200 { + match read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected attacker error: {other}"), + } + } + assert!(closed, "attacker session must fail closed"); + } else { + let payload = [1u8, 9, 8, 7]; + let mut plaintext = Vec::new(); + for _ in 0..4 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + + let got = read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("benign session must parse") + .expect("benign session must return a frame"); + assert_eq!(got.0.as_ref(), &payload); + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("mixed-population worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_parallel_patterns_no_hang_or_panic() { + let mut set = JoinSet::new(); + + for case in 0..40u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(5000 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut seed = 0x9E37_79B9u64 ^ (case << 8); + let mut plaintext = Vec::with_capacity(2048); + for _ in 0..256 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let is_tiny = (seed & 1) == 0; + if is_tiny { + plaintext.push(0x00); + } else { + plaintext.push(0x01); + plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]); + } + } + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + for _ in 0..320 { + let step = run_relay_test_step_timeout( + "fuzz case read step", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected fuzz case error: {other}"), + } + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("fuzz worker must not panic"); + } +} diff --git 
a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs
new file mode 100644
index 0000000..cbbc971
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs
@@ -0,0 +1,435 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+use tokio::time::{Duration as TokioDuration, sleep};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB300_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
+    match proto {
+        ProtoTag::Abridged => plaintext.push(0x00),
+        ProtoTag::Intermediate | ProtoTag::Secure => {
+            plaintext.extend_from_slice(&0u32.to_le_bytes())
+        }
+    }
+}
+
+fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
+    match proto {
+        ProtoTag::Abridged => {
+            plaintext.push(0x01);
+            plaintext.extend_from_slice(&payload);
+        }
+        ProtoTag::Intermediate | ProtoTag::Secure => {
+            plaintext.extend_from_slice(&4u32.to_le_bytes());
+            plaintext.extend_from_slice(&payload);
+        }
+    }
+}
+
+async fn write_chunked_with_jitter(
+    writer: &mut tokio::io::DuplexStream,
+    bytes: &[u8],
+    mut seed: u64,
+) {
+    let mut offset = 0usize;
+    while offset < bytes.len() {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+        let chunk_len = 1 + ((seed as usize) & 0x1f);
+        let end = (offset + chunk_len).min(bytes.len());
+        writer.write_all(&bytes[offset..end]).await.unwrap();
+
+        let delay_ms = ((seed >> 16) % 3) as u64;
+        if delay_ms > 0 {
+            sleep(TokioDuration::from_millis(delay_ms)).await;
+        }
+        offset = end;
+    }
+}
+
+async fn read_once_with_state<T: AsyncRead + Unpin + Send>(
+    crypto_reader: &mut CryptoReader<T>,
+    proto: ProtoTag,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    idle_state: &mut RelayClientIdleState,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    let buffer_pool = Arc::new(BufferPool::new());
+    let stats = Stats::new();
+    let idle_policy = make_enabled_idle_policy();
+    let last_downstream_activity_ms = AtomicU64::new(0);
+    read_client_payload_with_idle_policy(
+        crypto_reader,
+        proto,
+        1024,
+        &buffer_pool,
+        forensics,
+        frame_counter,
+        &stats,
+        &idle_policy,
+        idle_state,
+        &last_downstream_activity_ms,
+        forensics.started_at,
+    )
+    .await
+}
+
+fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
+    matches!(result, Err(ProxyError::Proxy(_)))
+        || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
+}
+
+#[tokio::test]
+async fn intermediate_chunked_zero_flood_fail_closed() {
+    let (reader, mut writer) = duplex(4096);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6101, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let mut plaintext = Vec::with_capacity(4 * 256);
+    for _ in 0..256 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
+    }
+    let encrypted = encrypt_for_reader(&plaintext);
+    write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
+    drop(writer);
+
+    let result = run_relay_test_step_timeout(
+        "intermediate flood read",
+        read_once_with_state(
+            &mut crypto_reader,
+            ProtoTag::Intermediate,
+            &forensics,
+            &mut frame_counter,
+            &mut idle_state,
+        ),
+    )
+    .await;
+
+    assert!(
+        is_fail_closed_outcome(&result),
+        "zero-length flood must fail closed via debt guard or idle timeout"
+    );
+    assert_eq!(frame_counter, 0);
+}
+
+#[tokio::test]
+async fn secure_chunked_zero_flood_fail_closed() {
+    let (reader, mut writer) = duplex(4096);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6102, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let mut plaintext = Vec::with_capacity(4 * 256);
+    for _ in 0..256 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Secure);
+    }
+    let encrypted = encrypt_for_reader(&plaintext);
+    write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
+    drop(writer);
+
+    let result = run_relay_test_step_timeout(
+        "secure flood read",
+        read_once_with_state(
+            &mut crypto_reader,
+            ProtoTag::Secure,
+            &forensics,
+            &mut frame_counter,
+            &mut idle_state,
+        ),
+    )
+    .await;
+
+    assert!(
+        is_fail_closed_outcome(&result),
+        "secure zero-length flood must fail closed via debt guard or idle timeout"
+    );
+    assert_eq!(frame_counter, 0);
+}
+
+#[tokio::test]
+async fn intermediate_chunked_alternating_attack_closes_before_eof() {
+    let (reader, mut writer) = duplex(8192);
+    let mut crypto_reader = make_crypto_reader(reader);
+    let started = Instant::now();
+    let forensics = make_forensics(6103, started);
+    let mut frame_counter = 0u64;
+    let mut idle_state = RelayClientIdleState::new(started);
+
+    let mut plaintext = Vec::with_capacity(8 * 200);
+    for n in 0..180u8 {
+        append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
+        append_real_frame(
+            &mut plaintext,
+            ProtoTag::Intermediate,
+            [n, n ^ 1, n ^ 2, n ^ 3],
+        );
+    }
+    let encrypted = encrypt_for_reader(&plaintext);
+
+    let writer_task = tokio::spawn(async move {
+        write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
+        drop(writer);
+    });
+
+    let mut closed = false;
+    for _ in 0..240 {
+        let step = run_relay_test_step_timeout(
+            "intermediate alternating read step",
+            read_once_with_state(
+                &mut crypto_reader,
+                ProtoTag::Intermediate,
+                &forensics,
+                &mut frame_counter,
+                &mut idle_state,
+            ),
+        )
+        .await;
+
+        match step {
+            Ok(Some(_)) => {}
+            Err(ProxyError::Proxy(_)) => {
+                closed = true;
+                break;
+            }
+            Ok(None) => break,
+            Err(other) => panic!("unexpected intermediate alternating error: {other}"),
+        }
+    }
+
+    writer_task
+        .await
+        .expect("intermediate writer task must not panic");
+    assert!(closed, "intermediate alternating attack must 
fail closed"); +} + +#[tokio::test] +async fn secure_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6104, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = run_relay_test_step_timeout( + "secure alternating read step", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected secure alternating error: {other}"), + } + } + + writer_task + .await + .expect("secure writer task must not panic"); + assert!(closed, "secure alternating attack must fail closed"); +} + +#[tokio::test] +async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6105, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [9u8, 8, 7, 6]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("intermediate safe burst should parse") + .expect("intermediate safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn secure_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6106, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [3u8, 1, 4, 1]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + append_real_frame(&mut plaintext, ProtoTag::Secure, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("secure safe burst should parse") + .expect("secure safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn light_fuzz_proto_chunking_outcomes_are_bounded() { + let mut seed = 
0xDEAD_BEEF_2026_0322u64;
+
+    for case in 0..48u64 {
+        seed ^= seed << 7;
+        seed ^= seed >> 9;
+        seed ^= seed << 8;
+
+        let proto = if (seed & 1) == 0 {
+            ProtoTag::Intermediate
+        } else {
+            ProtoTag::Secure
+        };
+
+        let (reader, mut writer) = duplex(8192);
+        let mut crypto_reader = make_crypto_reader(reader);
+        let started = Instant::now();
+        let forensics = make_forensics(6200 + case, started);
+        let mut frame_counter = 0u64;
+        let mut idle_state = RelayClientIdleState::new(started);
+
+        let mut stream = Vec::new();
+        let mut local_seed = seed ^ case;
+        for _ in 0..220 {
+            local_seed ^= local_seed << 7;
+            local_seed ^= local_seed >> 9;
+            local_seed ^= local_seed << 8;
+            if (local_seed & 1) == 0 {
+                append_tiny_frame(&mut stream, proto);
+            } else {
+                let b = (local_seed >> 8) as u8;
+                append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
+            }
+        }
+
+        let encrypted = encrypt_for_reader(&stream);
+        write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
+        drop(writer);
+
+        for _ in 0..260 {
+            let step = run_relay_test_step_timeout(
+                "fuzz proto read step",
+                read_once_with_state(
+                    &mut crypto_reader,
+                    proto,
+                    &forensics,
+                    &mut frame_counter,
+                    &mut idle_state,
+                ),
+            )
+            .await;
+
+            match step {
+                Ok(Some((_payload, _))) => {}
+                Err(ProxyError::Proxy(_)) => break,
+                Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
+                Ok(None) => break,
+                Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
+            }
+        }
+    }
+}
diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs
new file mode 100644
index 0000000..fad87d0
--- /dev/null
+++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs
@@ -0,0 +1,804 @@
+use super::*;
+use crate::crypto::AesCtr;
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
+use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use std::time::Instant;
+use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
+
+fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
+where
+    T: AsyncRead + Unpin + Send + 'static,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
+    let key = [0u8; 32];
+    let iv = 0u128;
+    let mut cipher = AesCtr::new(&key, iv);
+    cipher.encrypt(plaintext)
+}
+
+fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
+    RelayForensicsState {
+        trace_id: 0xB100_0000 + conn_id,
+        conn_id,
+        user: format!("tiny-frame-debt-user-{conn_id}"),
+        peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
+        peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
+        started_at,
+        bytes_c2me: 0,
+        bytes_me2c: Arc::new(AtomicU64::new(0)),
+        desync_all_full: false,
+    }
+}
+
+fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
+    RelayClientIdlePolicy {
+        enabled: true,
+        soft_idle: Duration::from_millis(50),
+        hard_idle: Duration::from_millis(120),
+        grace_after_downstream_activity: Duration::from_secs(0),
+        legacy_frame_read_timeout: Duration::from_millis(50),
+    }
+}
+
+async fn read_bounded<T: AsyncRead + Unpin + Send>(
+    crypto_reader: &mut CryptoReader<T>,
+    proto_tag: ProtoTag,
+    buffer_pool: &Arc<BufferPool>,
+    forensics: &RelayForensicsState,
+    frame_counter: &mut u64,
+    stats: &Stats,
+    idle_policy: &RelayClientIdlePolicy,
+    idle_state: &mut RelayClientIdleState,
+    last_downstream_activity_ms: &AtomicU64,
+    session_started_at: Instant,
+) -> Result<Option<(PooledBuffer, bool)>> {
+    run_relay_test_step_timeout(
+        "tiny-frame debt read step",
+        read_client_payload_with_idle_policy(
+            crypto_reader,
+            proto_tag,
+            1024,
+            buffer_pool,
+            forensics,
+            frame_counter,
+            stats,
+            idle_policy,
+            idle_state,
+            last_downstream_activity_ms,
+            session_started_at,
+        ),
+    )
+    .await
+}
+
+/// Pure model of the tiny-frame debt guard exercised by the wire-level tests
+/// below: each tiny frame adds TINY_FRAME_DEBT_PER_TINY units of debt, each
+/// real frame drains one unit, and the stream closes once debt reaches
+/// TINY_FRAME_DEBT_LIMIT. Returns (close step if any, final debt, real frames).
+fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
+    let mut debt = 0u32;
+    let mut reals = 0usize;
+    for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
+        if is_tiny {
+            debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
+            if debt >= TINY_FRAME_DEBT_LIMIT {
+                return (Some(idx + 1), debt, reals);
+            }
+        } else {
+            reals = reals.saturating_add(1);
+            debt = debt.saturating_sub(1);
+        }
+    }
+    (None, debt, reals)
+}
+
+#[test]
+fn tiny_frame_debt_constants_match_security_budget_expectations() {
+    assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
+    assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
+}
+
+#[test]
+fn relay_client_idle_state_initial_debt_is_zero() {
+    let state = RelayClientIdleState::new(Instant::now());
+    assert_eq!(state.tiny_frame_debt, 0);
+}
+
+#[test]
+fn on_client_frame_does_not_reset_tiny_frame_debt() {
+    let now = Instant::now();
+    let mut state = RelayClientIdleState::new(now);
+    state.tiny_frame_debt = 77;
+    state.on_client_frame(now);
+    assert_eq!(state.tiny_frame_debt, 77);
+}
+
+#[test]
+fn tiny_frame_debt_increment_is_saturating() {
+    let mut debt = u32::MAX - 1;
+    debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
+    assert_eq!(debt, u32::MAX);
+}
+
+#[test]
+fn tiny_frame_debt_decrement_is_saturating() {
+    let mut debt = 0u32;
+    debt = debt.saturating_sub(1);
+    assert_eq!(debt, 0);
+}
+
+#[test]
+fn consecutive_tiny_frames_close_exactly_at_threshold() {
+    let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
+    let pattern = vec![true; max_tiny_without_close];
+    let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert_eq!(closed_at, Some(max_tiny_without_close));
+}
+
+#[test]
+fn one_less_than_threshold_tiny_frames_do_not_close() {
+    let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
+    let pattern = vec![true; tiny_count];
+    let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert_eq!(closed_at, None);
+    assert!(debt < TINY_FRAME_DEBT_LIMIT);
+}
+
+#[test]
+fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
+    let mut pattern = Vec::with_capacity(512);
+    for _ in 0..256 {
+        pattern.push(true);
+        pattern.push(false);
+    }
+    let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert!(closed_at.is_some());
+    assert!(
+        reals <= 80,
+        "expected bounded real frames before close, got {reals}"
+    );
+}
+
+#[test]
+fn alternating_one_to_eight_is_stable_for_long_runs() {
+    let mut pattern = Vec::with_capacity(9 * 5000);
+    for _ in 0..5000 {
+        pattern.push(true);
+        for _ in 0..8 {
+            pattern.push(false);
+        }
+    }
+    let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert_eq!(closed_at, None);
+    assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
+}
+
+#[test]
+fn alternating_one_to_seven_eventually_closes() {
+    let mut pattern = Vec::with_capacity(8 * 2000);
+    for _ in 0..2000 {
+        pattern.push(true);
+        for _ in 0..7 {
+            pattern.push(false);
+        }
+    }
+    let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
+    assert!(
+        closed_at.is_some(),
+        "1:7 tiny-to-real must eventually close"
+    );
+}
+
+#[test]
+fn 
two_tiny_one_real_closes_faster_than_one_to_one() { + let mut one_to_one = Vec::with_capacity(512); + for _ in 0..256 { + one_to_one.push(true); + one_to_one.push(false); + } + + let mut two_to_one = Vec::with_capacity(768); + for _ in 0..256 { + two_to_one.push(true); + two_to_one.push(true); + two_to_one.push(false); + } + + let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len()); + let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len()); + assert!(a_close.is_some() && b_close.is_some()); + assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0)); +} + +#[test] +fn burst_then_drain_can_recover_without_close() { + let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize; + let mut pattern = Vec::with_capacity(burst_tiny + 600); + for _ in 0..burst_tiny { + pattern.push(true); + } + pattern.extend(std::iter::repeat_n(false, 600)); + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert_eq!(debt, 0); +} + +#[test] +fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() { + let mut seed = 0xA5A5_91C3_2026_0322u64; + for _case in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = 512 + ((seed as usize) & 0x3ff); + let mut pattern = Vec::with_capacity(len); + let mut local_seed = seed; + for _ in 0..len { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + pattern.push((local_seed & 1) == 0); + } + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + if closed_at.is_none() { + assert!(debt < TINY_FRAME_DEBT_LIMIT); + } + assert!(debt <= u32::MAX); + } +} + +#[test] +fn stress_many_independent_simulations_keep_isolated_debt_state() { + for idx in 0..2048usize { + let mut pattern = Vec::with_capacity(64); + for j in 0..64usize { + pattern.push(((idx ^ j) & 3) == 0); + } + let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY)); + } +} + +#[tokio::test] +async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(11, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(12, 
session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Secure, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn intermediate_alternating_zero_and_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(13, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(3000); + for idx in 0..160u8 { + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]); + } + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..220 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error while probing alternating close: {other}"), + } + } + + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(14, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(64); + for _ in 0..8 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&[1, 2, 3, 4]); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let first = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match first { + Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]), + Err(e) => panic!("unexpected close after small tiny burst: {e}"), + Ok(None) => panic!("unexpected EOF before real frame"), + } +} + +#[tokio::test] +async fn 
idle_policy_enabled_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "idle policy enabled must fail closed for pure zero-length flood" + ); +} + +#[tokio::test] +async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(256 * 6); + for idx in 0..=255u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]); + } + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("alternating flood bytes must be writable"); + drop(writer); + + let mut saw_proxy_close = false; + for _ in 0..300 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some((_payload, _quickack))) => {} + Err(ProxyError::Proxy(_)) => { + saw_proxy_close = true; + break; + } + Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"), + Ok(None) => panic!("unexpected EOF before debt-based closure"), + Err(other) => panic!("unexpected error before close: {other}"), + } + } + + assert!( + saw_proxy_close, + "alternating tiny/real sequence must eventually fail closed" + ); +} + +#[tokio::test] +async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(3, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [7u8, 8, 9, 10]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let 
encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero frame must be writable"); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("valid frame should decode") + .expect("valid frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn abridged_quickack_tiny_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(21, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0x80u8; 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "quickack-marked zero-length flood must fail closed" + ); +} + +#[tokio::test] +async fn abridged_extended_zero_len_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(22, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut flood_plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]); + } + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "extended zero-length abridged flood must fail closed" + ); +} + +#[tokio::test] +async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() { + let mut plaintext = Vec::with_capacity(9 * 300); + for idx in 0..300usize { + plaintext.push(0x00); + for _ in 0..8 { + let b = idx as u8; + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]); + } + } + + // Keep the test single-task and deterministic: make duplex capacity larger than the + // generated ciphertext so write_all cannot block waiting for a concurrent reader. 
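+ // tokio::io::duplex(max_buf_size) bounds how many bytes may sit unread in each + // direction; once that buffer fills, write_all parks until the peer reads, which + // would deadlock here since this task only starts reading after the write completes.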
+ let duplex_capacity = plaintext.len().saturating_add(1024); + let (reader, mut writer) = duplex(duplex_capacity); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(23, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..3000 { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Err(other) => panic!("unexpected error in 1:8 wire test: {other}"), + } + } + + assert!( + !closed, + "wire-level 1:8 tiny-to-real pattern should not trigger debt close" + ); +} + +#[tokio::test] +async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() { + let mut seed = 0xD1CE_BAAD_2026_0322u64; + + for case_idx in 0..32u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let events = 300 + ((seed as usize) & 0xff); + let mut pattern = Vec::with_capacity(events); + let mut local = seed; + for _ in 0..events { + local ^= local << 7; + local ^= local >> 9; + local ^= local << 8; + pattern.push((local & 0x03) == 0); + } + + let mut plaintext = Vec::with_capacity(events * 6); + for (idx, tiny) in pattern.iter().copied().enumerate() { + if tiny { + plaintext.push(0x00); + } else { + let b = (idx as u8) ^ (case_idx as u8); + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]); + } + } + + let (reader, mut writer) = duplex(16 * 1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(500 + case_idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + let mut observed_close = false; + + for _ in 0..(events + 8) { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + observed_close = true; + break; + } + Err(other) => panic!("unexpected fuzz error: {other}"), + } + } + + assert_eq!( + observed_close, + expected_close.is_some(), + "wire parser behavior must match debt model for case {case_idx}" + ); + } +} diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs new file mode 100644 index 0000000..dbf6c4c --- /dev/null +++ 
b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; + +fn make_crypto_reader<T>(reader: T) -> CryptoReader<T> +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB000_0000 + conn_id, + conn_id, + user: format!("zero-len-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +#[tokio::test] +async fn adversarial_legacy_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + + let flood_plaintext = vec![0u8; 128]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + match result { + Err(ProxyError::Proxy(msg)) => { + assert!( + msg.contains("Excessive zero-length"), + "legacy mode must close flood with explicit zero-length reason, got: {msg}" + ); + } + Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"), + Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"), + Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"), + } +} + +#[tokio::test] +async fn business_abridged_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + + let payload = [1u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero abridged frame must be writable"); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("valid abridged frame should decode") + .expect("valid abridged frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1, "quickack flag must remain false"); + assert_eq!(frame_counter, 1); +} diff
--git a/src/proxy/tests/relay_adversarial_tests.rs index 14754cd..38e6fc7 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() { async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = "quota-mid-user"; - let quota = 5000; + let quota = 5000u64; + let c2s_buf_size = 1024usize; let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() { client_writer, server_reader, server_writer, - 1024, + c2s_buf_size, 1024, user, Arc::clone(&stats), @@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() { other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), } - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let mut overshoot_bytes = 0usize; + let mut buf = [0u8; 256]; + loop { + match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n), + Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"), + Err(_) => break, + } + } + + assert!( + overshoot_bytes <= c2s_buf_size, + "post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}" + ); + assert!( + stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64), + "accounted quota must remain bounded by one in-flight chunk overshoot" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..1bb00a6 --- /dev/null +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::time::Instant; + +struct ScriptedWriter { + scripted_writes: Arc<Mutex<VecDeque<usize>>>, + write_calls: Arc<AtomicUsize>, +} + +impl ScriptedWriter { + fn new(script: &[usize], write_calls: Arc<AtomicUsize>) -> Self { + Self { + scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())), + write_calls, + } + } +} + +impl AsyncWrite for ScriptedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + let planned = this + .scripted_writes + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(buf.len()); + Poll::Ready(Ok(planned.min(buf.len()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +fn make_stats_io_with_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: &[usize], +) -> ( + StatsIo<ScriptedWriter>, + Arc<Stats>, + Arc<AtomicUsize>, + Arc<AtomicBool>, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let write_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded
= Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedWriter::new(script, write_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, write_calls, quota_exceeded) +} + +#[tokio::test] +async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { + let user = "direct-partial-charge-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]); + let payload = vec![0xAB; 64 * 1024]; + + let n1 = io + .write(&payload) + .await + .expect("first partial write must succeed"); + let n2 = io + .write(&payload) + .await + .expect("second partial write must succeed"); + let n3 = io.write(&payload).await.expect("tail write must succeed"); + + assert_eq!(n1, 8 * 1024); + assert_eq!(n2, 8 * 1024); + assert_eq!(n3, 48 * 1024); + assert_eq!(write_calls.load(Ordering::Relaxed), 3); + assert_eq!( + stats.get_user_quota_used(user), + (n1 + n2 + n3) as u64, + "quota accounting must follow committed bytes only" + ); + assert_eq!( + stats.get_user_total_octets(user), + (n1 + n2 + n3) as u64, + "telemetry octets should match committed bytes on successful writes" + ); + assert!( + !quota_exceeded.load(Ordering::Acquire), + "quota flag should stay false under large remaining budget" + ); +} + +#[tokio::test] +async fn direct_hybrid_branch_selection_matches_contract() { + let near_limit = 256 * 1024u64; + let near_remaining = 32 * 1024u64; + let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-near-limit-hard-check-user", + near_limit, + near_limit - near_remaining, + &[4 * 1024], + ); + let near_payload = vec![0x11; 4 * 1024]; + let near_written = near_io + .write(&near_payload) + .await + .expect("near-limit write must succeed"); + assert_eq!(near_written, 4 * 1024); + assert_eq!( + near_io.quota_bytes_since_check, 0, + "near-limit branch must go through immediate hard check" + ); + + let (mut far_small_io, _stats, _calls, _flag) = + make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]); + let far_small_payload = vec![0x22; 4 * 1024]; + let far_small_written = far_small_io + .write(&far_small_payload) + .await + .expect("small far-from-limit write must succeed"); + assert_eq!(far_small_written, 4 * 1024); + assert_eq!( + far_small_io.quota_bytes_since_check, + 4 * 1024, + "small far-from-limit write must go through amortized path" + ); + + let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-far-large-hard-check-user", + 1_048_576, + 0, + &[32 * 1024], + ); + let far_large_payload = vec![0x33; 32 * 1024]; + let far_large_written = far_large_io + .write(&far_large_payload) + .await + .expect("large write must succeed"); + assert_eq!(far_large_written, 32 * 1024); + assert_eq!( + far_large_io.quota_bytes_since_check, 0, + "large write must force immediate hard check even far from limit" + ); +} + +#[tokio::test] +async fn remaining_before_zero_rejects_without_calling_inner_writer() { + let user = "direct-zero-remaining-user"; + let limit = 8u64; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, limit, limit, &[1]); + + let err = io + .write(&[0x44]) + .await + .expect_err("write must fail when remaining quota is zero"); + + assert!( + is_quota_io_error(&err), + "zero-remaining gate must return typed quota I/O error" + ); + assert_eq!( + 
write_calls.load(Ordering::Relaxed), + 0, + "inner poll_write must not be called when remaining quota is zero" + ); + assert!( + quota_exceeded.load(Ordering::Acquire), + "zero-remaining gate must set exceeded flag" + ); + assert_eq!(stats.get_user_quota_used(user), limit); +} + +#[tokio::test] +async fn exceeded_flag_blocks_following_poll_before_inner_write() { + let user = "direct-exceeded-visibility-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1, 0, &[1, 1]); + + let first = io + .write(&[0x55]) + .await + .expect("first byte should consume remaining quota"); + assert_eq!(first, 1); + assert!( + quota_exceeded.load(Ordering::Acquire), + "hard check should store quota_exceeded after boundary hit" + ); + + let second = io + .write(&[0x66]) + .await + .expect_err("next write must be rejected by early exceeded gate"); + assert!( + is_quota_io_error(&second), + "following write must fail with typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 1, + "second write must be cut before touching inner writer" + ); + assert_eq!(stats.get_user_quota_used(user), 1); +} + +#[test] +fn adaptive_interval_clamp_matches_contract() { + assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024); + assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024); + + assert!(should_immediate_quota_check(32 * 1024, 4 * 1024)); + assert!(should_immediate_quota_check(1_048_576, 32 * 1024)); + assert!(!should_immediate_quota_check(1_048_576, 4 * 1024)); +} diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs index 080240a..9a32b26 100644 --- a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: Duration) total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn integration_full_duplex_exact_budget_then_hard_cutoff() { let stats = Arc::new(Stats::new()); @@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test] async fn negative_preloaded_quota_blocks_both_directions_immediately() { let stats = Arc::new(Stats::new()); let user = "quota-preloaded-cutoff-user"; - stats.add_user_octets_from(user, 5); + preload_user_quota(stats.as_ref(), user, 5); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() { relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 5); + assert!(stats.get_user_quota_used(user) <= 5); } #[tokio::test] @@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test] @@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo delivered_to_server + delivered_to_client <= quota as usize, "combined forwarded bytes must never exceed configured quota" ); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar "fuzz case {case}: forwarded bytes must not exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes must not exceed quota" ); } @@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must hold under concurrent mixed-direction relay stress" ); assert!( diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..8ce1c26 --- /dev/null +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,399 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +async fn read_available<R: AsyncReadExt + Unpin>( + reader: &mut R, + budget: Duration, +) -> usize { + let start = tokio::time::Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 128]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[tokio::test] +async fn positive_quota_path_forwards_both_directions_within_limit() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-positive-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(16), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); + server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); + client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(relay_result.is_ok()); + assert!(stats.get_user_quota_used(user) <= 16); +} + +#[tokio::test] +async fn
negative_preloaded_quota_forbids_any_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-negative-user"; + preload_user_quota(stats.as_ref(), user, 8); + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA]).await.unwrap(); + server_peer.write_all(&[0xBB]).await.unwrap(); + + assert_eq!( + read_available(&mut server_peer, Duration::from_millis(120)).await, + 0 + ); + assert_eq!( + read_available(&mut client_peer, Duration::from_millis(120)).await, + 0 + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!(stats.get_user_quota_used(user) <= 8); +} + +#[tokio::test] +async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-edge-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0xFE]), + server_peer.write_all(&[0xEF]), + ); + + let mut buf = [0u8; 1]; + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); + + assert!(delivered_s2c + delivered_c2s <= 1); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-blackhat-user"; + let quota = 24u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut total_forwarded = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(3), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_quota_used(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { + let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE); + + for case in 0..32u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-extended-fuzz-{case}"); + let quota = rng.random_range(1u64..=35u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total_forwarded = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::<bool>() { + let _ = client_peer.write_all(&[rng.random::<u8>()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), server_peer.read(&mut one)).await + { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[rng.random::<u8>()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), client_peer.read(&mut one)).await + { + total_forwarded += n; + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ..
})) + ); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_quota_used(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_for_one_user_obey_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-stress-user".to_string(); + let quota = 64u64; + + let mut tasks = Vec::new(); + + for worker in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + + tasks.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total = 0usize; + for step in 0..64u8 { + if relay.is_finished() { + break; + } + if (step as usize + worker as usize) % 2 == 0 { + let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), server_peer.read(&mut one)).await + { + total += n; + } + } else { + let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), client_peer.read(&mut one)).await + { + total += n; + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})) + ); + total + })); + } + + let mut delivered = 0usize; + for task in tasks { + delivered += task.await.unwrap(); + } + + assert!(stats.get_user_quota_used(&user) <= quota); + assert!(delivered <= quota as usize); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs deleted file mode 100644 index e29e86e..0000000 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ /dev/null @@ -1,438 +0,0 @@ -use super::*; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::Barrier; -use tokio::time::Instant; - -#[test] -fn quota_lock_same_user_returns_same_arc_instance() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-same-user"); - let b = quota_user_lock("quota-lock-same-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[test] -fn quota_lock_parallel_same_user_reuses_single_lock() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-lock-parallel-same"; - let mut handles = Vec::new(); - - for _ in 0..64 { - handles.push(std::thread::spawn(move || quota_user_lock(user))); - } - - let first = handles - .remove(0) - .join() - .expect("thread must return lock handle"); - - for handle in handles { - let got = handle.join().expect("thread must return lock handle"); - assert!(Arc::ptr_eq(&first, &got)); - } -} - -#[test] -fn quota_lock_unique_users_materialize_distinct_entries() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-distinct-{}", std::process::id()); - let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2)) - .map(|idx| format!("{base}-{idx}")) - .collect(); - - for user in &users { - let _ = quota_user_lock(user); - } - - for user in &users { - assert!( - map.get(user).is_some(), - "lock cache must contain entry for {user}" - ); - } -} - -#[test] -fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-churn-{}", std::process::id()); - for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { - let _ = quota_user_lock(&format!("{base}-{idx}")); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay bounded under unique-user churn" - ); -} - -#[test] -fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "cache must be saturated for overflow check" - ); - - let overflow_user = format!("quota-overflow-{}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!(
map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow lock cache" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow user lock must stay outside bounded cache under saturation" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive stable striped overflow lock while saturated" - ); - - drop(retained); -} - -#[test] -fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - // Saturate with retained strong references first so parallel tests cannot - // reclaim our fixture entries before we validate the reclaim path. - let prefix = format!("quota-reclaim-drop-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); - let overflow = quota_user_lock(&overflow_user); - - assert!( - map.get(&overflow_user).is_some(), - "after reclaiming stale entries, overflow user should become cacheable" - ); - assert!( - Arc::strong_count(&overflow) >= 2, - "cacheable overflow lock should be held by both map and caller" - ); -} - -#[test] -fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-held-{}-{idx}", - std::process::id() - ))); - } - - let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); - let a = quota_user_lock(&overflow_user); - let b = quota_user_lock(&overflow_user); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must not receive distinct locks under saturation because that enables quota race bypass" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-race-held-{}-{idx}", - std::process::id() - ))); - } - - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-race-user-{}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| { - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[label]).await - }) - }; - - let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - - let _ = tokio::time::timeout(Duration::from_secs(2), async { - let _ = one.await.expect("task one must not panic"); - let _ = two.await.expect("task two must not panic"); - }) - .await - .expect("quota race workers must
complete"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "saturated lock path must never overshoot quota for same user" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-stress-held-{}-{idx}", - std::process::id() - ))); - } - - for round in 0..128u32 { - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let one = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x31]).await - }) - }; - - let two = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x32]).await - }) - }; - - let _ = one.await.expect("stress task one must not panic"); - let _ = two.await.expect("stress task two must not panic"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "round {round}: saturated path must not overshoot quota" - ); - } - - drop(retained); -} - -#[test] -fn quota_error_classifier_accepts_internal_quota_sentinel_only() { - let err = quota_io_error(); - assert!(is_quota_io_error(&err)); -} - -#[test] -fn quota_error_classifier_rejects_plain_permission_denied() { - let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); - assert!(!is_quota_io_error(&err)); -} - -#[test] -fn quota_lock_test_scope_recovers_after_guard_poison() { - let poison_result = std::thread::spawn(|| { - let _guard = super::quota_user_lock_test_scope(); - panic!("intentional test-only guard poison"); - }) - .join(); - assert!(poison_result.is_err(), "poison setup thread must panic"); - - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-poison-recovery-user"); - let b = quota_user_lock("quota-lock-poison-recovery-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[tokio::test] -async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { - let stats = Arc::new(Stats::new()); - let user = "quota-zero-user"; - - let (mut client_peer, relay_client) = duplex(2048); - let (relay_server, mut server_peer) = duplex(2048); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 512, - 512, - user, - Arc::clone(&stats), - Some(0), - 
Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(b"x") - .await - .expect("client write must succeed"); - - let mut probe = [0u8; 1]; - let forwarded = - tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; - if let Ok(Ok(n)) = forwarded { - assert_eq!(n, 0, "zero quota path must not forward payload bytes"); - } - - let result = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate under zero quota") - .expect("relay task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); -} - -#[tokio::test] -async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { - let stats = Arc::new(Stats::new()); - - let (mut client_peer, relay_client) = duplex(8192); - let (relay_server, mut server_peer) = duplex(8192); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "quota-none-burst-user", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - let c2s = vec![0xA5; 2048]; - let s2c = vec![0x5A; 1536]; - - client_peer - .write_all(&c2s) - .await - .expect("client burst write must succeed"); - let mut got_c2s = vec![0u8; c2s.len()]; - server_peer - .read_exact(&mut got_c2s) - .await - .expect("server must receive c2s burst"); - assert_eq!(got_c2s, c2s); - - server_peer - .write_all(&s2c) - .await - .expect("server burst write must succeed"); - let mut got_s2c = vec![0u8; s2c.len()]; - client_peer - .read_exact(&mut got_s2c) - .await - .expect("client must receive s2c burst"); - assert_eq!(got_s2c, s2c); - - drop(client_peer); - drop(server_peer); - - let done = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate after peers close") - .expect("relay task must not panic"); - assert!(done.is_ok()); -} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 5714f48..04a7020 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -32,6 +32,7 @@ async fn drain_available(reader: &mut R, out: &mut Vec #[tokio::test] async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + const MAX_INPUT_CHUNK: usize = 12; for case in 0..64u64 { let stats = Arc::new(Stats::new()); @@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); assert!( - recv_at_server.len() + recv_at_client.len() <= quota as usize, - "fuzz case {case}: delivered bytes exceed quota" + recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK, + "fuzz case {case}: delivered bytes exceed bounded post-check overshoot" ); assert!( - stats.get_user_total_octets(&user) <= quota, - "fuzz case {case}: accounted bytes exceed quota" + stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64, + "fuzz case {case}: accounted bytes exceed bounded post-check overshoot" ); } @@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); 
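// the S->C direction must satisfy the same invariant: a quota cutoff may truncate delivery, but never reorder or corrupt it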
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); - assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK); + assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64); } } @@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode let stats = Arc::new(Stats::new()); let user = "quota-model-stress-user"; let quota = 96u64; + const WORKERS: usize = 6; + const MAX_WORKER_CHUNK: u64 = 10; + let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK; let mut workers = Vec::new(); - for worker_id in 0..6u64 { + for worker_id in 0..WORKERS as u64 { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_total_octets(user) <= quota, - "global per-user quota must never overshoot under concurrent multi-relay model load" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global per-user accounted bytes must stay within bounded post-write overshoot" ); assert!( - delivered_sum <= quota as usize, - "aggregate delivered bytes across relays must remain within global quota" + delivered_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate delivered bytes must stay within bounded post-write overshoot" ); } diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index dfbab85..f1e6c34 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -19,13 +19,22 @@ async fn read_available(reader: &mut R, budget_ms: u64) -> total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; + let quota = 10u64; + let preloaded = 9u64; + let attempted_chunk = [0x11, 0x22, 0x33, 0x44]; + let max_post_write_overshoot = attempted_chunk.len() as u64; // Leave only 1 byte remaining under quota. - stats.add_user_octets_from(user, 9); + preload_user_quota(stats.as_ref(), user, preloaded); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ 512, user, Arc::clone(&stats), - Some(10), + Some(quota), Arc::new(BufferPool::new()), )); // Single chunk attempts to cross remaining budget (4 > 1). 
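// Charging is post-write, so the relay can commit at most this one in-flight chunk past the boundary before the hard check trips; the assertions below bound forwarded bytes and accounted quota by exactly that window.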
- client_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .unwrap(); + client_peer.write_all(&attempted_chunk).await.unwrap(); client_peer.shutdown().await.unwrap(); let forwarded = read_available(&mut server_peer, 60).await; @@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ .expect("relay must terminate after quota overflow attempt") .expect("relay task must not panic"); - assert_eq!( - forwarded, 0, - "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + assert!( + forwarded <= attempted_chunk.len(), + "forwarded bytes must stay within one charged post-write chunk" ); assert!(matches!( relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); assert!( - stats.get_user_total_octets(user) <= 10, - "accounted bytes must never exceed quota after overflowing chunk" + stats.get_user_quota_used(user) <= quota + max_post_write_overshoot, + "accounted bytes must stay within bounded post-write overshoot" ); } @@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of let user = "quota-overflow-regression-boundary"; // Leave exactly 4 bytes remaining. - stats.add_user_octets_from(user, 6); + preload_user_quota(stats.as_ref(), user, 6); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-stress"; let quota = 12u64; + const WORKERS: usize = 4; + const BURST_LEN: usize = 64; + let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64; let mut handles = Vec::new(); - for _ in 0..4usize { + for _ in 0..WORKERS { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { }); // Aggressive sender tries to overflow shared user quota. 
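// Worst case, each of the WORKERS relays commits one BURST_LEN write before its post-write check fires, which is the max_parallel_post_write_overshoot bound asserted below.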
- let burst = vec![0x5Au8; 64]; + let burst = vec![0x5Au8; BURST_LEN]; let _ = client_peer.write_all(&burst).await; let _ = client_peer.shutdown().await; @@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { } assert!( - forwarded_sum <= quota as usize, - "aggregate forwarded bytes across relays must stay within global user quota" + forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate forwarded bytes must stay within bounded post-write overshoot window" ); assert!( - stats.get_user_total_octets(user) <= quota, - "global accounted bytes must stay within quota under overflow stress" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global accounted bytes must stay within bounded post-write overshoot window" ); } diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs deleted file mode 100644 index 9f68258..0000000 --- a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs +++ /dev/null @@ -1,294 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -fn saturate_lock_cache() -> Vec<Arc<Mutex<()>>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); - } - retained -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_writer_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-writer-positive"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); - - // Let the initial deferred wake fire while contention is still active.
- tokio::time::sleep(Duration::from_millis(4)).await; - - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), writer) - .await - .expect("writer must be re-polled and complete after lock release") - .expect("writer task must not panic"); - assert!(completed.is_ok(), "writer must complete after lock release"); -} - -#[tokio::test] -async fn edge_reader_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-reader-edge"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before read"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let reader = tokio::spawn(async move { - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), reader) - .await - .expect("reader must be re-polled and complete after lock release") - .expect("reader task must not panic"); - assert!(completed.is_ok(), "reader must complete after lock release"); -} - -#[tokio::test] -async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-adversarial"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before adversarial write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); - - // Force multiple scheduler rounds while lock remains held so the first - // deferred wake has already been consumed under contention. 
- for _ in 0..32 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - let completed = timeout(Duration::from_millis(300), writer) - .await - .expect("writer must not stay parked forever after release") - .expect("writer task must not panic"); - assert!(completed.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_single_release_event() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = format!("quota-liveness-integration-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let barrier = Arc::new(Barrier::new(13)); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before launching waiters"); - - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let barrier = Arc::clone(&barrier); - waiters.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(4096), - quota_exceeded, - tokio::time::Instant::now(), - ); - barrier.wait().await; - io.write_all(&[0x33]).await - })); - } - - barrier.wait().await; - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let outcome = waiter.await.expect("waiter must not panic"); - assert!( - outcome.is_ok(), - "waiter must resume and complete after release" - ); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn light_fuzz_release_timing_matrix_preserves_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - let mut seed = 0xD1CE_F00D_0123_4567u64; - for round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let delay_ms = 1 + (seed & 0x7) as u64; - let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock in fuzz round"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); - - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), writer) - .await - .expect("fuzz round writer must complete") - .expect("fuzz writer task must not panic"); - assert!( - done.is_ok(), - "fuzz round writer must not stall after release" - ); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_repeated_contention_cycles_remain_live() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - for cycle in 0..40u32 { - let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before stress cycle"); - - let mut tasks = Vec::new(); - for _ in 0..6 { - let stats = Arc::clone(&stats); - let user = 
user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - io.write_all(&[0x55]).await - })); - } - - tokio::task::yield_now().await; - drop(held_guard); - - timeout(Duration::from_millis(700), async { - for task in tasks { - let outcome = task.await.expect("stress task must not panic"); - assert!(outcome.is_ok(), "stress writer must complete"); - } - }) - .await - .expect("stress cycle must finish in bounded time"); - } -} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs deleted file mode 100644 index fa4878a..0000000 --- a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs +++ /dev/null @@ -1,310 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc<Self>) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc<Self>) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec<Arc<tokio::sync::Mutex<()>>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_contended_writer_emits_deferred_wake_for_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-positive-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(pending.is_pending()); - - timeout(Duration::from_millis(100), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must receive deferred wake"); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!( - ready.is_ready(), - "writer must progress after contention release" - ); -} - -#[tokio::test] -async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-blackhat-writer"; - - let lock = quota_user_lock(user); - let
held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - for _ in 0..512 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); - assert!( - poll.is_pending(), - "writer must stay pending while lock is held" - ); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending writer retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_read_path_contention_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-read-edge"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling reader"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - for _ in 0..512 { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending reader retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let mut buf = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-fuzz-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before fuzz polling"); - - let counters_w = Arc::new(SharedCounters::new()); - let mut writer_io = StatsIo::new( - tokio::io::sink(), - counters_w, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let counters_r = Arc::new(SharedCounters::new()); - let mut reader_io = StatsIo::new( - tokio::io::empty(), - counters_r, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut seed = 0xBADC_0FFE_EE11_2211u64; - let mut storage = [0u8; 1]; - - for _ in 0..1024 { 
- seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - if (seed & 1) == 0 { - let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); - assert!(poll.is_pending()); - } else { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - } - tokio::task::yield_now().await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 192, - "mixed contention fuzz must keep deferred wake count tightly bounded" - ); - - drop(held_guard); - let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); - assert!(ready_w.is_ready()); - - let mut buf = ReadBuf::new(&mut storage); - let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(ready_r.is_ready()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] -async fn stress_many_contended_writers_complete_after_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-waker-stress-user".to_string(); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before launching contended tasks"); - - let mut tasks = Vec::new(); - for _ in 0..32 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - io.write_all(&[0xAA]).await - })); - } - - for _ in 0..8 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress task must not panic"); - assert!(result.is_ok(), "task must complete after lock release"); - } - }) - .await - .expect("all contended writer tasks must finish in bounded time after release"); -} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs deleted file mode 100644 index 50cdfa3..0000000 --- a/src/proxy/tests/relay_security_tests.rs +++ /dev/null @@ -1,1284 +0,0 @@ -use super::relay_bidirectional; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::future::poll_fn; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, ReadBuf}; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc<Self>) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc<Self>) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -#[tokio::test] -async fn quota_lock_contention_does_not_self_wake_pending_writer() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-contention-user"; - - let lock = super::quota_user_lock(user); - let _held_lock = lock -
.try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - poll.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "contended quota lock must not self-wake immediately and spin the executor" - ); -} - -#[tokio::test] -async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-writer-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - first.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "deferred wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must schedule a deferred wake in bounded time"); - let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_first_yield >= 1, - "contended writer must schedule at least one deferred wake for liveness" - ); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!( - second.is_pending(), - "writer remains pending while lock is still held" - ); - - for _ in 0..8 { - tokio::task::yield_now().await; - } - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - wakes_after_first_yield, - "writer contention should not schedule unbounded wake storms before lock acquisition" - ); - - drop(held_lock); - let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!( - released.is_ready(), - "writer must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-read-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - 
.expect("test must hold the per-user quota lock before polling reader"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!( - first.is_pending(), - "reader must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "read contention wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("read contention must schedule a deferred wake in bounded time"); - - drop(held_lock); - let mut buf_after_release = ReadBuf::new(&mut storage); - let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!( - released.is_ready(), - "reader must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn relay_bidirectional_enforces_live_user_quota() { - let stats = Arc::new(Stats::new()); - let user = "quota-user"; - stats.add_user_octets_from(user, 6); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - Some(8), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .expect("client write must succeed"); - - let mut forwarded = [0u8; 4]; - let _ = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut forwarded), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), - "relay must surface a typed quota error once live quota is exceeded" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0xde, 0xad, 0xbe, 0xef]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - 
Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "no full server payload should be forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() - { - let stats = Arc::new(Stats::new()); - let quota_user = "partial-leak-user"; - stats.add_user_octets_from(quota_user, 3); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 8]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { - let stats = Arc::new(Stats::new()); - let quota_user = "zero-quota-user"; - - for payload_len in [1usize, 16, 512, 4096] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0x7f; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under zero-quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "zero quota must not forward any server bytes for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "zero quota must terminate with the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] 
-async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { - let stats = Arc::new(Stats::new()); - let quota_user = "exact-boundary-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x91, 0x92, 0x93, 0x94]) - .await - .expect("server write must succeed at exact quota boundary"); - - let mut observed = [0u8; 4]; - client_peer - .read_exact(&mut observed) - .await - .expect("client must receive the full payload at the exact quota boundary"); - assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish after exact boundary delivery") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must close with a typed quota error after reaching the exact boundary" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "client-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x51, 0x52, 0x53, 0x54]) - .await - .expect("client write must succeed even when quota is already exhausted"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "client payload must not be fully forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-fuzz-user"; - stats.add_user_octets_from(quota_user, 2); - - for payload_len in [1usize, 32, 1024, 8192] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - 
Arc::clone(&stats), - Some(2), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xaa; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == payload_len), - "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must keep returning the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_terminates_on_activity_timeout() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "timeout-user"; - - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, // No quota - Arc::new(BufferPool::new()), - )); - - // Wait past the activity timeout threshold (1800 seconds) + buffer - tokio::time::sleep(Duration::from_secs(1805)).await; - - // Resume time to process timeouts - tokio::time::resume(); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must finish inside bounded timeout due to inactivity cutoff") - .expect("relay task must not panic"); - - assert!( - relay_result.is_ok(), - "relay should complete successfully on scheduled inactivity timeout" - ); - - // Verify client/server sockets are closed - drop(client_peer); - drop(server_peer); -} - -#[tokio::test] -async fn relay_bidirectional_watchdog_resists_premature_execution() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "activity-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - // Advance by half the timeout - tokio::time::sleep(Duration::from_secs(900)).await; - - // Provide activity - client_peer - .write_all(&[0xaa, 0xbb]) - .await - .expect("client write must succeed"); - client_peer.flush().await.unwrap(); - - // Advance by another half (total time since start is 1800, but since last activity is 900) - tokio::time::sleep(Duration::from_secs(900)).await; - - tokio::time::resume(); - - // Re-evaluating the task, it should NOT have timed out and still be pending - let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; - assert!( - relay_result.is_err(), - "Relay must not exit prematurely as long as activity was received before timeout" - ); - - // Explicitly drop sockets to cleanly shut down relay loop - drop(client_peer); - 
drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete securely after client disconnection") - .expect("relay task must not panic"); - assert!(completion.is_ok(), "relay exits clean"); -} - -#[tokio::test] -async fn relay_bidirectional_half_closure_terminates_cleanly() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "half-close", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Half closure: drop the client completely but leave the server active. - drop(client_peer); - - // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. - // Eventually dropping the server cleanly closes the task. - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_zero_length_noise_fuzzing() { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) - for _ in 0..100 { - client_peer.write_all(&[]).await.unwrap(); - } - client_peer.write_all(&[1, 2, 3]).await.unwrap(); - client_peer.flush().await.unwrap(); - - let mut buf = [0u8; 3]; - server_peer.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, &[1, 2, 3]); - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_asymmetric_backpressure() { - let stats = Arc::new(Stats::new()); - // Give the client stream an extremely narrow throughput limit explicitly - let (client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "slowloris", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xba; 65536]; // 64k payload - - // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! - let write_res = - tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; - - assert!( - write_res.is_err(), - "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" 
- ); - - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap(); - assert!( - completion.is_ok() || completion.is_err(), - "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" - ); -} - -use rand::{RngExt, SeedableRng, rngs::StdRng}; - -#[tokio::test] -async fn relay_bidirectional_light_fuzzing_temporal_jitter() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz-user", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let mut rng = StdRng::seed_from_u64(0xDEADBEEF); - - for _ in 0..10 { - // Vary timing significantly up to 1600 seconds (limit is 1800s) - let jitter = rng.random_range(100..1600); - tokio::time::sleep(Duration::from_secs(jitter)).await; - - client_peer.write_all(&[0x11]).await.unwrap(); - client_peer.flush().await.unwrap(); - - // Ensure task has not died - let res = timeout(Duration::from_millis(10), &mut relay_task).await; - assert!( - res.is_err(), - "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" - ); - } - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -struct FaultyReader { - error_once: Option<io::Error>, -} - -struct TwoPartyGate { - arrivals: AtomicUsize, - total_bytes: AtomicUsize, - wakers: Mutex<Vec<Waker>>, -} - -impl TwoPartyGate { - fn new() -> Self { - Self { - arrivals: AtomicUsize::new(0), - total_bytes: AtomicUsize::new(0), - wakers: Mutex::new(Vec::new()), - } - } - - fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { - if self.arrivals.load(Ordering::Relaxed) >= 2 { - return true; - } - - let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); - if prev + 1 >= 2 { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - for waker in wakers.drain(..) { - waker.wake(); - } - true - } else { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - wakers.push(cx.waker().clone()); - false - } - } - - fn total_bytes(&self) -> usize { - self.total_bytes.load(Ordering::Relaxed) - } -} -
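`TwoPartyGate` is a two-party rendezvous: the first poller parks its waker and returns false, the second arrival wakes every parked waker, and from then on `arrive_or_park` always passes, so the two racing I/O calls are released in the same instant instead of serializing by scheduler accident. A minimal usage sketch under those semantics; the `rendezvous` helper is illustrative only, not part of the deleted file:

// Illustrative only: both tasks block in poll until the other has arrived,
// which forces their subsequent quota CAS attempts to overlap.
async fn rendezvous(gate: &TwoPartyGate) {
    std::future::poll_fn(|cx| {
        if gate.arrive_or_park(cx) {
            std::task::Poll::Ready(())
        } else {
            std::task::Poll::Pending
        }
    })
    .await;
}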
-struct GateWriter { - gate: Arc<TwoPartyGate>, - entered: bool, -} - -impl GateWriter { - fn new(gate: Arc<TwoPartyGate>) -> Self { - Self { - gate, - entered: false, - } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll<io::Result<usize>> { - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - self.gate - .total_bytes - .fetch_add(buf.len(), Ordering::Relaxed); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { - Poll::Ready(Ok(())) - } -} - -struct GateReader { - gate: Arc<TwoPartyGate>, - entered: bool, - emitted: bool, -} - -impl GateReader { - fn new(gate: Arc<TwoPartyGate>) -> Self { - Self { - gate, - entered: false, - emitted: false, - } - } -} - -impl AsyncRead for GateReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll<io::Result<()>> { - if self.emitted { - return Poll::Ready(Ok(())); - } - - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - buf.put_slice(&[0x42]); - self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); - self.emitted = true; - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-write".to_string(); - - let writer_a = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let writer_b = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut w = writer_a; - AsyncWriteExt::write_all(&mut w, &[0x01]).await - }); - let task_b = tokio::spawn(async move { - let mut w = writer_b; - AsyncWriteExt::write_all(&mut w, &[0x02]).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user writes must not forward more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user writes must not account over limit" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-read".to_string(); - - let reader_a = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), -
tokio::time::Instant::now(), - ); - - let reader_b = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut r = reader_a; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - let task_b = tokio::spawn(async move { - let mut r = reader_b; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user reads must not consume more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user reads must not account over limit" - ); -} - -#[tokio::test] -async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { - let stats = Arc::new(Stats::new()); - let user = "parallel-quota-user"; - - for _ in 0..128 { - let (mut client_peer_a, relay_client_a) = duplex(256); - let (relay_server_a, mut server_peer_a) = duplex(256); - let (mut client_peer_b, relay_client_b) = duplex(256); - let (relay_server_b, mut server_peer_b) = duplex(256); - - let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); - let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); - let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); - let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); - - let relay_a = tokio::spawn(relay_bidirectional( - client_reader_a, - client_writer_a, - server_reader_a, - server_writer_a, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let relay_b = tokio::spawn(relay_bidirectional( - client_reader_b, - client_writer_b, - server_reader_b, - server_writer_b, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let _ = tokio::join!( - client_peer_a.write_all(&[0x01]), - server_peer_a.write_all(&[0x02]), - client_peer_b.write_all(&[0x03]), - server_peer_b.write_all(&[0x04]), - ); - - let _ = timeout( - Duration::from_millis(50), - poll_fn(|cx| { - let mut one = [0u8; 1]; - let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); - Poll::Ready(()) - }), - ) - .await; - - drop(client_peer_a); - drop(server_peer_a); - drop(client_peer_b); - drop(server_peer_b); - - let _ = timeout(Duration::from_secs(1), relay_a).await; - let _ = timeout(Duration::from_secs(1), relay_b).await; - - assert!( - stats.get_user_total_octets(user) <= 1, - "parallel relays must not exceed configured quota" - ); - } -} - -impl FaultyReader { - fn permission_denied_with_message(message: impl Into<String>) -> Self { - Self { - error_once: Some(io::Error::new( - io::ErrorKind::PermissionDenied, - message.into(), - )), - } - } -} - -impl AsyncRead for FaultyReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll<io::Result<()>> { - if let Some(err) = self.error_once.take() { - return Poll::Ready(Err(err)); - } - Poll::Ready(Ok(())) - } -} -
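`FaultyReader` feeds the relay a transport-level `PermissionDenied` whose message text is attacker-controlled. The tests below pin down the classification rule: only the relay's own quota check may produce `ProxyError::DataQuotaExceeded`; transport errors stay `ProxyError::Io` no matter what their message says. A sketch of the message-sniffing anti-pattern the fuzz cases exist to rule out; the `classify_by_message` helper is hypothetical, and it assumes the variant's `user` field is a `String`:

// Anti-pattern, for contrast only: classifying by error text lets a remote
// peer forge a quota cutoff simply by choosing its error message.
fn classify_by_message(user: &str, err: std::io::Error) -> ProxyError {
    if err.to_string().contains("quota exceeded") {
        // Wrong: an attacker-influenced string is promoted to a typed quota error.
        return ProxyError::DataQuotaExceeded { user: user.to_string() };
    }
    ProxyError::Io(err)
}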
-#[tokio::test] -async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message("user data quota exceeded"), - tokio::io::sink(), - 1024, - 1024, - "non-quota-permission-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "non-quota transport PermissionDenied errors must remain IO errors" - ); -} - -#[tokio::test] -async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { - let mut rng = StdRng::seed_from_u64(0xA11CE0B5); - - for i in 0..128u64 { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let random_len = rng.random_range(1..=48); - let mut msg = String::with_capacity(random_len); - for _ in 0..random_len { - let ch = (b'a' + (rng.random::<u8>() % 26)) as char; - msg.push(ch); - } - // Include the legacy quota string in a subset of fuzz cases to validate - // collision resistance against message-based classification. - if i % 7 == 0 { - msg = "user data quota exceeded".to_string(); - } - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message(msg), - tokio::io::sink(), - 1024, - 1024, - "fuzz-perm-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "transport PermissionDenied case must stay typed as IO regardless of message content" - ); - } -} - -#[tokio::test] -async fn relay_half_close_keeps_reverse_direction_progressing() { - let stats = Arc::new(Stats::new()); - let user = "half-close-user"; - - let (client_peer, relay_client) = duplex(1024); - let (relay_server, server_peer) = duplex(1024); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); - let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 8192, - 8192, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - sp_writer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .unwrap(); - sp_writer.shutdown().await.unwrap(); - - let mut inbound = [0u8; 4]; - cp_reader.read_exact(&mut inbound).await.unwrap(); - assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); - - cp_writer - .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) - .await - .unwrap(); - let mut outbound = [0u8; 4]; - sp_reader.read_exact(&mut outbound).await.unwrap(); - assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); - - relay_task.abort(); - let joined = relay_task.await; - assert!(joined.is_err(), "aborted relay task must return join error"); -} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index bdabe81..2d1f413 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -26,6 +26,28 @@ enum RouteConnectionGauge { Middle, } +#[derive(Debug, Clone, Copy)] +pub enum MeD2cFlushReason { + QueueDrain, + BatchFrames, + BatchBytes, + MaxDelay, + AckImmediate, + Close, +} + +#[derive(Debug, Clone, Copy)] +pub enum MeD2cWriteMode { +
Coalesced, + Split, +} + +#[derive(Debug, Clone, Copy)] +pub enum MeD2cQuotaRejectStage { + PreWrite, + PostWrite, +} + #[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"] pub struct RouteConnectionLease { stats: Arc<Stats>, @@ -106,6 +128,8 @@ pub struct Stats { me_crc_mismatch: AtomicU64, me_seq_mismatch: AtomicU64, me_endpoint_quarantine_total: AtomicU64, + me_endpoint_quarantine_unexpected_total: AtomicU64, + me_endpoint_quarantine_draining_suppressed_total: AtomicU64, me_kdf_drift_total: AtomicU64, me_kdf_port_only_drift_total: AtomicU64, me_hardswap_pending_reuse_total: AtomicU64, @@ -140,6 +164,44 @@ pub struct Stats { me_route_drop_queue_full: AtomicU64, me_route_drop_queue_full_base: AtomicU64, me_route_drop_queue_full_high: AtomicU64, + me_d2c_batches_total: AtomicU64, + me_d2c_batch_frames_total: AtomicU64, + me_d2c_batch_bytes_total: AtomicU64, + me_d2c_flush_reason_queue_drain_total: AtomicU64, + me_d2c_flush_reason_batch_frames_total: AtomicU64, + me_d2c_flush_reason_batch_bytes_total: AtomicU64, + me_d2c_flush_reason_max_delay_total: AtomicU64, + me_d2c_flush_reason_ack_immediate_total: AtomicU64, + me_d2c_flush_reason_close_total: AtomicU64, + me_d2c_data_frames_total: AtomicU64, + me_d2c_ack_frames_total: AtomicU64, + me_d2c_payload_bytes_total: AtomicU64, + me_d2c_write_mode_coalesced_total: AtomicU64, + me_d2c_write_mode_split_total: AtomicU64, + me_d2c_quota_reject_pre_write_total: AtomicU64, + me_d2c_quota_reject_post_write_total: AtomicU64, + me_d2c_frame_buf_shrink_total: AtomicU64, + me_d2c_frame_buf_shrink_bytes_total: AtomicU64, + me_d2c_batch_frames_bucket_1: AtomicU64, + me_d2c_batch_frames_bucket_2_4: AtomicU64, + me_d2c_batch_frames_bucket_5_8: AtomicU64, + me_d2c_batch_frames_bucket_9_16: AtomicU64, + me_d2c_batch_frames_bucket_17_32: AtomicU64, + me_d2c_batch_frames_bucket_gt_32: AtomicU64, + me_d2c_batch_bytes_bucket_0_1k: AtomicU64, + me_d2c_batch_bytes_bucket_1k_4k: AtomicU64, + me_d2c_batch_bytes_bucket_4k_16k: AtomicU64, + me_d2c_batch_bytes_bucket_16k_64k: AtomicU64, + me_d2c_batch_bytes_bucket_64k_128k: AtomicU64, + me_d2c_batch_bytes_bucket_gt_128k: AtomicU64, + me_d2c_flush_duration_us_bucket_0_50: AtomicU64, + me_d2c_flush_duration_us_bucket_51_200: AtomicU64, + me_d2c_flush_duration_us_bucket_201_1000: AtomicU64, + me_d2c_flush_duration_us_bucket_1001_5000: AtomicU64, + me_d2c_flush_duration_us_bucket_5001_20000: AtomicU64, + me_d2c_flush_duration_us_bucket_gt_20000: AtomicU64, + me_d2c_batch_timeout_armed_total: AtomicU64, + me_d2c_batch_timeout_fired_total: AtomicU64, me_writer_pick_sorted_rr_success_try_total: AtomicU64, me_writer_pick_sorted_rr_success_fallback_total: AtomicU64, me_writer_pick_sorted_rr_full_total: AtomicU64, @@ -174,14 +236,17 @@ pub struct Stats { me_writer_restored_same_endpoint_total: AtomicU64, me_writer_restored_fallback_total: AtomicU64, me_no_writer_failfast_total: AtomicU64, + me_hybrid_timeout_total: AtomicU64, me_async_recovery_trigger_total: AtomicU64, me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, + quota_write_fail_bytes_total: AtomicU64, + quota_write_fail_events_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, - user_stats: DashMap<String, UserStats>, + user_stats: DashMap<String, Arc<UserStats>>, user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock<Option<Instant>>, } @@ -194,9 +259,51 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + /// Total bytes charged against per-user quota admission. + /// + /// This counter is the single source of truth for quota enforcement and + /// intentionally tracks attempted traffic, not guaranteed delivery. + pub quota_used: AtomicU64, pub last_seen_epoch_secs: AtomicU64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QuotaReserveError { + LimitExceeded, + Contended, +} + +impl UserStats { + #[inline] + pub fn quota_used(&self) -> u64 { + self.quota_used.load(Ordering::Relaxed) + } + + /// Attempts one CAS reservation step against the quota counter. + /// + /// Callers control retry/yield policy. This primitive intentionally does + /// not block or sleep so both sync poll paths and async paths can wrap it + /// with their own contention strategy. + #[inline] + pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result<u64, QuotaReserveError> { + let current = self.quota_used.load(Ordering::Relaxed); + if bytes > limit.saturating_sub(current) { + return Err(QuotaReserveError::LimitExceeded); + } + + let next = current.saturating_add(bytes); + match self.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => Ok(next), + Err(_) => Err(QuotaReserveError::Contended), + } + } +} +
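The reserve step above is deliberately a single CAS attempt, so every caller supplies its own retry policy. A minimal caller-side sketch, assuming a tokio runtime; the `admit_bytes` helper and its signature are illustrative, not part of this change:

// Hypothetical admission loop around the single-step reserve primitive.
// On `Contended` the CAS merely lost a race, so re-reading and retrying is
// correct; on `LimitExceeded` the caller must fail closed before writing.
async fn admit_bytes(user_stats: &UserStats, bytes: u64, limit: u64) -> bool {
    loop {
        match user_stats.quota_try_reserve(bytes, limit) {
            Ok(_used_after_reserve) => return true,
            Err(QuotaReserveError::LimitExceeded) => return false,
            Err(QuotaReserveError::Contended) => tokio::task::yield_now().await,
        }
    }
}

A sync poll path could instead spin on `Contended` for a bounded number of iterations; the point is that the primitive never parks, so the contention policy stays at the call site. Paths that account bytes only after the I/O has committed use the separate post-write charge helper below rather than this reserve loop.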
impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -256,6 +363,74 @@ impl Stats { .store(Self::now_epoch_secs(), Ordering::Relaxed); } + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> { + self.maybe_cleanup_user_stats(); + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + Self::touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + Self::touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event.
+ #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + Self::touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + fn maybe_cleanup_user_stats(&self) { const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; @@ -594,6 +769,216 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_me_d2c_batches_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_frames_total(&self, frames: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_frames_total + .fetch_add(frames, Ordering::Relaxed); + } + } + pub fn add_me_d2c_batch_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_batch_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_flush_reason(&self, reason: MeD2cFlushReason) { + if !self.telemetry_me_allows_normal() { + return; + } + match reason { + MeD2cFlushReason::QueueDrain => { + self.me_d2c_flush_reason_queue_drain_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchFrames => { + self.me_d2c_flush_reason_batch_frames_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::BatchBytes => { + self.me_d2c_flush_reason_batch_bytes_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::MaxDelay => { + self.me_d2c_flush_reason_max_delay_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::AckImmediate => { + self.me_d2c_flush_reason_ack_immediate_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cFlushReason::Close => { + self.me_d2c_flush_reason_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_data_frames_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_data_frames_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_ack_frames_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_d2c_ack_frames_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn add_me_d2c_payload_bytes_total(&self, bytes: u64) { + if self.telemetry_me_allows_normal() { + self.me_d2c_payload_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_write_mode(&self, mode: MeD2cWriteMode) { + if !self.telemetry_me_allows_normal() { + return; + } + match mode { + MeD2cWriteMode::Coalesced => { + self.me_d2c_write_mode_coalesced_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cWriteMode::Split => { + self.me_d2c_write_mode_split_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_quota_reject_total(&self, stage: MeD2cQuotaRejectStage) { + if !self.telemetry_me_allows_normal() { + return; + } + match stage { + MeD2cQuotaRejectStage::PreWrite => { + self.me_d2c_quota_reject_pre_write_total + .fetch_add(1, Ordering::Relaxed); + } + MeD2cQuotaRejectStage::PostWrite => { + self.me_d2c_quota_reject_post_write_total + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_frame_buf_shrink(&self, bytes_freed: u64) { + if !self.telemetry_me_allows_normal() { + return; + } + self.me_d2c_frame_buf_shrink_total + .fetch_add(1, Ordering::Relaxed); + self.me_d2c_frame_buf_shrink_bytes_total + .fetch_add(bytes_freed, Ordering::Relaxed); + } + pub fn observe_me_d2c_batch_frames(&self, frames: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match 
frames { + 0 => {} + 1 => { + self.me_d2c_batch_frames_bucket_1 + .fetch_add(1, Ordering::Relaxed); + } + 2..=4 => { + self.me_d2c_batch_frames_bucket_2_4 + .fetch_add(1, Ordering::Relaxed); + } + 5..=8 => { + self.me_d2c_batch_frames_bucket_5_8 + .fetch_add(1, Ordering::Relaxed); + } + 9..=16 => { + self.me_d2c_batch_frames_bucket_9_16 + .fetch_add(1, Ordering::Relaxed); + } + 17..=32 => { + self.me_d2c_batch_frames_bucket_17_32 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_frames_bucket_gt_32 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_batch_bytes(&self, bytes: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match bytes { + 0..=1024 => { + self.me_d2c_batch_bytes_bucket_0_1k + .fetch_add(1, Ordering::Relaxed); + } + 1025..=4096 => { + self.me_d2c_batch_bytes_bucket_1k_4k + .fetch_add(1, Ordering::Relaxed); + } + 4097..=16_384 => { + self.me_d2c_batch_bytes_bucket_4k_16k + .fetch_add(1, Ordering::Relaxed); + } + 16_385..=65_536 => { + self.me_d2c_batch_bytes_bucket_16k_64k + .fetch_add(1, Ordering::Relaxed); + } + 65_537..=131_072 => { + self.me_d2c_batch_bytes_bucket_64k_128k + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_batch_bytes_bucket_gt_128k + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn observe_me_d2c_flush_duration_us(&self, duration_us: u64) { + if !self.telemetry_me_allows_debug() { + return; + } + match duration_us { + 0..=50 => { + self.me_d2c_flush_duration_us_bucket_0_50 + .fetch_add(1, Ordering::Relaxed); + } + 51..=200 => { + self.me_d2c_flush_duration_us_bucket_51_200 + .fetch_add(1, Ordering::Relaxed); + } + 201..=1000 => { + self.me_d2c_flush_duration_us_bucket_201_1000 + .fetch_add(1, Ordering::Relaxed); + } + 1001..=5000 => { + self.me_d2c_flush_duration_us_bucket_1001_5000 + .fetch_add(1, Ordering::Relaxed); + } + 5001..=20_000 => { + self.me_d2c_flush_duration_us_bucket_5001_20000 + .fetch_add(1, Ordering::Relaxed); + } + _ => { + self.me_d2c_flush_duration_us_bucket_gt_20000 + .fetch_add(1, Ordering::Relaxed); + } + } + } + pub fn increment_me_d2c_batch_timeout_armed_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_d2c_batch_timeout_armed_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_d2c_batch_timeout_fired_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_d2c_batch_timeout_fired_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_writer_pick_success_try_total(&self, mode: MeWriterPickMode) { if !self.telemetry_me_allows_normal() { return; @@ -821,6 +1206,11 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_me_hybrid_timeout_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_hybrid_timeout_total.fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_async_recovery_trigger_total(&self) { if self.telemetry_me_allows_normal() { self.me_async_recovery_trigger_total @@ -845,12 +1235,36 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_endpoint_quarantine_total(&self) { if self.telemetry_me_allows_normal() { self.me_endpoint_quarantine_total .fetch_add(1, Ordering::Relaxed); } } + pub 
+    pub fn increment_me_endpoint_quarantine_unexpected_total(&self) {
+        if self.telemetry_me_allows_normal() {
+            self.me_endpoint_quarantine_unexpected_total
+                .fetch_add(1, Ordering::Relaxed);
+        }
+    }
+    pub fn increment_me_endpoint_quarantine_draining_suppressed_total(&self) {
+        if self.telemetry_me_allows_normal() {
+            self.me_endpoint_quarantine_draining_suppressed_total
+                .fetch_add(1, Ordering::Relaxed);
+        }
+    }
     pub fn increment_me_kdf_drift_total(&self) {
         if self.telemetry_me_allows_normal() {
             self.me_kdf_drift_total.fetch_add(1, Ordering::Relaxed);
@@ -1103,6 +1517,14 @@ impl Stats {
     pub fn get_me_endpoint_quarantine_total(&self) -> u64 {
         self.me_endpoint_quarantine_total.load(Ordering::Relaxed)
     }
+    pub fn get_me_endpoint_quarantine_unexpected_total(&self) -> u64 {
+        self.me_endpoint_quarantine_unexpected_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_endpoint_quarantine_draining_suppressed_total(&self) -> u64 {
+        self.me_endpoint_quarantine_draining_suppressed_total
+            .load(Ordering::Relaxed)
+    }
     pub fn get_me_kdf_drift_total(&self) -> u64 {
         self.me_kdf_drift_total.load(Ordering::Relaxed)
     }
@@ -1229,6 +1651,143 @@ impl Stats {
     pub fn get_me_route_drop_queue_full_high(&self) -> u64 {
         self.me_route_drop_queue_full_high.load(Ordering::Relaxed)
     }
+    pub fn get_me_d2c_batches_total(&self) -> u64 {
+        self.me_d2c_batches_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_total(&self) -> u64 {
+        self.me_d2c_batch_frames_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_total(&self) -> u64 {
+        self.me_d2c_batch_bytes_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_queue_drain_total(&self) -> u64 {
+        self.me_d2c_flush_reason_queue_drain_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_batch_frames_total(&self) -> u64 {
+        self.me_d2c_flush_reason_batch_frames_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_batch_bytes_total(&self) -> u64 {
+        self.me_d2c_flush_reason_batch_bytes_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_max_delay_total(&self) -> u64 {
+        self.me_d2c_flush_reason_max_delay_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_ack_immediate_total(&self) -> u64 {
+        self.me_d2c_flush_reason_ack_immediate_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_reason_close_total(&self) -> u64 {
+        self.me_d2c_flush_reason_close_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_data_frames_total(&self) -> u64 {
+        self.me_d2c_data_frames_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_ack_frames_total(&self) -> u64 {
+        self.me_d2c_ack_frames_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_payload_bytes_total(&self) -> u64 {
+        self.me_d2c_payload_bytes_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_write_mode_coalesced_total(&self) -> u64 {
+        self.me_d2c_write_mode_coalesced_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_write_mode_split_total(&self) -> u64 {
+        self.me_d2c_write_mode_split_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_quota_reject_pre_write_total(&self) -> u64 {
+        self.me_d2c_quota_reject_pre_write_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_quota_reject_post_write_total(&self) -> u64 {
+        self.me_d2c_quota_reject_post_write_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_frame_buf_shrink_total(&self) -> u64 {
+        self.me_d2c_frame_buf_shrink_total.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_frame_buf_shrink_bytes_total(&self) -> u64 {
+        self.me_d2c_frame_buf_shrink_bytes_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_1(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_1.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_2_4(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_2_4.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_5_8(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_5_8.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_9_16(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_9_16.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_17_32(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_17_32
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_frames_bucket_gt_32(&self) -> u64 {
+        self.me_d2c_batch_frames_bucket_gt_32
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_0_1k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_0_1k.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_1k_4k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_4k_16k
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_16k_64k
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_64k_128k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_64k_128k
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_bytes_bucket_gt_128k(&self) -> u64 {
+        self.me_d2c_batch_bytes_bucket_gt_128k
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_0_50(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_0_50
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_51_200(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_51_200
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_201_1000(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_201_1000
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_1001_5000(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_1001_5000
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_5001_20000(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_5001_20000
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_flush_duration_us_bucket_gt_20000(&self) -> u64 {
+        self.me_d2c_flush_duration_us_bucket_gt_20000
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 {
+        self.me_d2c_batch_timeout_armed_total
+            .load(Ordering::Relaxed)
+    }
+    pub fn get_me_d2c_batch_timeout_fired_total(&self) -> u64 {
+        self.me_d2c_batch_timeout_fired_total
+            .load(Ordering::Relaxed)
+    }
     pub fn get_me_writer_pick_sorted_rr_success_try_total(&self) -> u64 {
         self.me_writer_pick_sorted_rr_success_try_total
             .load(Ordering::Relaxed)
@@ -1345,6 +1904,9 @@ impl Stats {
     pub fn get_me_no_writer_failfast_total(&self) -> u64 {
         self.me_no_writer_failfast_total.load(Ordering::Relaxed)
     }
+    pub fn get_me_hybrid_timeout_total(&self) -> u64 {
+        self.me_hybrid_timeout_total.load(Ordering::Relaxed)
+    }
     pub fn get_me_async_recovery_trigger_total(&self) -> u64 {
         self.me_async_recovery_trigger_total.load(Ordering::Relaxed)
     }
@@ -1359,19 +1921,19 @@ impl Stats {
         self.ip_reservation_rollback_quota_limit_total
             .load(Ordering::Relaxed)
     }
+    pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
+        self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
+    }
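// NOTE (editorial sketch): the bucket getters above expose raw per-range
// counts. An exporter that wants Prometheus-style cumulative buckets can fold
// them with a running sum; the `le` labels below mirror the frame buckets in
// this file, and the input array stands in for the getter calls:
fn cumulative_frame_buckets(counts: [u64; 6]) -> [(&'static str, u64); 6] {
    let labels = ["1", "4", "8", "16", "32", "+Inf"];
    let mut out = [("", 0u64); 6];
    let mut running = 0u64;
    for i in 0..6 {
        running += counts[i]; // histogram buckets are cumulative in le order
        out[i] = (labels[i], running);
    }
    out
}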
+    pub fn get_quota_write_fail_events_total(&self) -> u64 {
+        self.quota_write_fail_events_total.load(Ordering::Relaxed)
+    }
     pub fn increment_user_connects(&self, user: &str) {
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.connects.fetch_add(1, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
+        let stats = self.get_or_create_user_stats_handle(user);
+        Self::touch_user_stats(stats.as_ref());
         stats.connects.fetch_add(1, Ordering::Relaxed);
     }
 
@@ -1379,14 +1941,8 @@
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.curr_connects.fetch_add(1, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
+        let stats = self.get_or_create_user_stats_handle(user);
+        Self::touch_user_stats(stats.as_ref());
         stats.curr_connects.fetch_add(1, Ordering::Relaxed);
     }
 
@@ -1395,9 +1951,8 @@
             return true;
         }
 
-        self.maybe_cleanup_user_stats();
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
+        let stats = self.get_or_create_user_stats_handle(user);
+        Self::touch_user_stats(stats.as_ref());
 
         let counter = &stats.curr_connects;
         let mut current = counter.load(Ordering::Relaxed);
@@ -1422,7 +1977,7 @@
     pub fn decrement_user_curr_connects(&self, user: &str) {
         self.maybe_cleanup_user_stats();
         if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
+            Self::touch_user_stats(stats.value().as_ref());
             let counter = &stats.curr_connects;
             let mut current = counter.load(Ordering::Relaxed);
             loop {
@@ -1453,60 +2008,32 @@
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
-        stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
+        let stats = self.get_or_create_user_stats_handle(user);
+        self.add_user_octets_from_handle(stats.as_ref(), bytes);
     }
 
     pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
-        stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
+        let stats = self.get_or_create_user_stats_handle(user);
+        self.add_user_octets_to_handle(stats.as_ref(), bytes);
     }
 
     pub fn increment_user_msgs_from(&self, user: &str) {
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
-        stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
+        let stats = self.get_or_create_user_stats_handle(user);
+        self.increment_user_msgs_from_handle(stats.as_ref());
    }
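// NOTE (editorial sketch): the refactor above funnels every per-user update
// through one cached Arc<UserStats> handle instead of re-probing the map for
// each counter bump; the handle also keeps the stats object alive if a cleanup
// pass evicts the map entry (see the lifetime test further below). A minimal
// illustration of the lookup-or-insert helper, assuming a
// DashMap<String, Arc<UserStats>> field -- the real helper's body is not shown
// in this diff:
use dashmap::DashMap;
use std::sync::Arc;

#[derive(Default)]
struct UserStatsDemo; // stand-in for the real UserStats

fn get_or_create(map: &DashMap<String, Arc<UserStatsDemo>>, user: &str) -> Arc<UserStatsDemo> {
    if let Some(entry) = map.get(user) {
        return entry.value().clone(); // cheap refcount bump, no write lock
    }
    map.entry(user.to_string()).or_default().value().clone()
}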
 
     pub fn increment_user_msgs_to(&self, user: &str) {
         if !self.telemetry_user_enabled() {
             return;
         }
-        self.maybe_cleanup_user_stats();
-        if let Some(stats) = self.user_stats.get(user) {
-            Self::touch_user_stats(stats.value());
-            stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
-            return;
-        }
-        let stats = self.user_stats.entry(user.to_string()).or_default();
-        Self::touch_user_stats(stats.value());
-        stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
+        let stats = self.get_or_create_user_stats_handle(user);
+        self.increment_user_msgs_to_handle(stats.as_ref());
     }
 
     pub fn get_user_total_octets(&self, user: &str) -> u64 {
@@ -1519,6 +2046,13 @@ impl Stats {
             .unwrap_or(0)
     }
 
+    pub fn get_user_quota_used(&self, user: &str) -> u64 {
+        self.user_stats
+            .get(user)
+            .map(|s| s.quota_used.load(Ordering::Relaxed))
+            .unwrap_or(0)
+    }
+
     pub fn get_handshake_timeouts(&self) -> u64 {
         self.handshake_timeouts.load(Ordering::Relaxed)
     }
@@ -1584,7 +2118,7 @@ impl Stats {
             .load(Ordering::Relaxed)
     }
 
-    pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
+    pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc<UserStats>> {
         self.user_stats.iter()
     }
 
@@ -1732,6 +2266,22 @@ impl ReplayChecker {
         found
     }
 
+    fn check_only_internal(
+        &self,
+        data: &[u8],
+        shards: &[Mutex],
+        window: Duration,
+    ) -> bool {
+        self.checks.fetch_add(1, Ordering::Relaxed);
+        let idx = self.get_shard_idx(data);
+        let mut shard = shards[idx].lock();
+        let found = shard.check(data, Instant::now(), window);
+        if found {
+            self.hits.fetch_add(1, Ordering::Relaxed);
+        }
+        found
+    }
+
     fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) {
         self.additions.fetch_add(1, Ordering::Relaxed);
         let idx = self.get_shard_idx(data);
@@ -1755,7 +2305,7 @@ impl ReplayChecker {
         self.add_only(data, &self.handshake_shards, self.window)
     }
     pub fn check_tls_digest(&self, data: &[u8]) -> bool {
-        self.check_and_add_tls_digest(data)
+        self.check_only_internal(data, &self.tls_shards, self.tls_window)
     }
     pub fn add_tls_digest(&self, data: &[u8]) {
         self.add_only(data, &self.tls_shards, self.tls_window)
    }
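// NOTE (editorial sketch): check_tls_digest used to test *and* record the
// digest in one call; after this change it only tests membership, and
// add_tls_digest records separately. The two-phase shape lets a caller probe a
// digest early and commit it only once the handshake is actually accepted, so
// rejected attempts no longer poison the replay window. The split in
// isolation, with a plain HashSet standing in for the sharded windows:
use std::collections::HashSet;

struct TwoPhaseReplay {
    seen: HashSet<Vec<u8>>,
}

impl TwoPhaseReplay {
    // Phase 1: probe without mutating state.
    fn check(&self, digest: &[u8]) -> bool {
        self.seen.contains(digest)
    }
    // Phase 2: commit once the handshake is accepted.
    fn add(&mut self, digest: &[u8]) {
        self.seen.insert(digest.to_vec());
    }
}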
@@ -1859,6 +2409,7 @@ mod tests {
     use super::*;
     use crate::config::MeTelemetryLevel;
     use std::sync::Arc;
+    use std::sync::atomic::{AtomicU64, Ordering};
 
     #[test]
     fn test_stats_shared_counters() {
@@ -1898,9 +2449,83 @@
         stats.increment_me_crc_mismatch();
         stats.increment_me_keepalive_sent();
         stats.increment_me_route_drop_queue_full();
+        stats.increment_me_d2c_batches_total();
+        stats.add_me_d2c_batch_frames_total(4);
+        stats.add_me_d2c_batch_bytes_total(4096);
+        stats.increment_me_d2c_flush_reason(MeD2cFlushReason::BatchBytes);
+        stats.increment_me_d2c_write_mode(MeD2cWriteMode::Coalesced);
+        stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
+        stats.observe_me_d2c_frame_buf_shrink(1024);
+        stats.observe_me_d2c_batch_frames(4);
+        stats.observe_me_d2c_batch_bytes(4096);
+        stats.observe_me_d2c_flush_duration_us(120);
+        stats.increment_me_d2c_batch_timeout_armed_total();
+        stats.increment_me_d2c_batch_timeout_fired_total();
         assert_eq!(stats.get_me_crc_mismatch(), 0);
         assert_eq!(stats.get_me_keepalive_sent(), 0);
         assert_eq!(stats.get_me_route_drop_queue_full(), 0);
+        assert_eq!(stats.get_me_d2c_batches_total(), 0);
+        assert_eq!(stats.get_me_d2c_flush_reason_batch_bytes_total(), 0);
+        assert_eq!(stats.get_me_d2c_write_mode_coalesced_total(), 0);
+        assert_eq!(stats.get_me_d2c_quota_reject_pre_write_total(), 0);
+        assert_eq!(stats.get_me_d2c_frame_buf_shrink_total(), 0);
+        assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0);
+        assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0);
+        assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0);
+        assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0);
+        assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0);
+    }
+
+    #[test]
+    fn test_telemetry_policy_me_normal_blocks_d2c_debug_metrics() {
+        let stats = Stats::new();
+        stats.apply_telemetry_policy(TelemetryPolicy {
+            core_enabled: true,
+            user_enabled: true,
+            me_level: MeTelemetryLevel::Normal,
+        });
+
+        stats.increment_me_d2c_batches_total();
+        stats.add_me_d2c_batch_frames_total(2);
+        stats.add_me_d2c_batch_bytes_total(2048);
+        stats.increment_me_d2c_flush_reason(MeD2cFlushReason::QueueDrain);
+        stats.observe_me_d2c_batch_frames(2);
+        stats.observe_me_d2c_batch_bytes(2048);
+        stats.observe_me_d2c_flush_duration_us(100);
+        stats.increment_me_d2c_batch_timeout_armed_total();
+        stats.increment_me_d2c_batch_timeout_fired_total();
+
+        assert_eq!(stats.get_me_d2c_batches_total(), 1);
+        assert_eq!(stats.get_me_d2c_batch_frames_total(), 2);
+        assert_eq!(stats.get_me_d2c_batch_bytes_total(), 2048);
+        assert_eq!(stats.get_me_d2c_flush_reason_queue_drain_total(), 1);
+        assert_eq!(stats.get_me_d2c_batch_frames_bucket_2_4(), 0);
+        assert_eq!(stats.get_me_d2c_batch_bytes_bucket_1k_4k(), 0);
+        assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_51_200(), 0);
+        assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 0);
+        assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 0);
+    }
+
+    #[test]
+    fn test_telemetry_policy_me_debug_enables_d2c_debug_metrics() {
+        let stats = Stats::new();
+        stats.apply_telemetry_policy(TelemetryPolicy {
+            core_enabled: true,
+            user_enabled: true,
+            me_level: MeTelemetryLevel::Debug,
+        });
+
+        stats.observe_me_d2c_batch_frames(7);
+        stats.observe_me_d2c_batch_bytes(70_000);
+        stats.observe_me_d2c_flush_duration_us(1400);
+        stats.increment_me_d2c_batch_timeout_armed_total();
+        stats.increment_me_d2c_batch_timeout_fired_total();
+
+        assert_eq!(stats.get_me_d2c_batch_frames_bucket_5_8(), 1);
+        assert_eq!(stats.get_me_d2c_batch_bytes_bucket_64k_128k(), 1);
+        assert_eq!(stats.get_me_d2c_flush_duration_us_bucket_1001_5000(), 1);
+        assert_eq!(stats.get_me_d2c_batch_timeout_armed_total(), 1);
+        assert_eq!(stats.get_me_d2c_batch_timeout_fired_total(), 1);
     }
 
     #[test]
@@ -1952,6 +2577,137 @@ mod tests {
         }
         assert_eq!(checker.stats().total_entries, 500);
     }
+
+    #[test]
+    fn test_quota_reserve_under_contention_hits_limit_exactly() {
+        let user_stats = Arc::new(UserStats::default());
+        let successes = Arc::new(AtomicU64::new(0));
+        let limit = 8_192u64;
+        let mut workers = Vec::new();
+
+        for _ in 0..8 {
+            let user_stats = user_stats.clone();
+            let successes = successes.clone();
+            workers.push(std::thread::spawn(move || {
+                loop {
+                    match user_stats.quota_try_reserve(1, limit) {
+                        Ok(_) => {
+                            successes.fetch_add(1, Ordering::Relaxed);
+                        }
+                        Err(QuotaReserveError::Contended) => {
+                            std::hint::spin_loop();
+                        }
+                        Err(QuotaReserveError::LimitExceeded) => {
+                            break;
+                        }
+                    }
+                }
+            }));
+        }
+
+        for worker in workers {
+            worker.join().expect("worker thread must finish");
+        }
+
+        assert_eq!(
+            successes.load(Ordering::Relaxed),
+            limit,
+            "successful reservations must stop exactly at limit"
+        );
+        assert_eq!(user_stats.quota_used(), limit);
+    }
+
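// NOTE (editorial sketch): the test above exercises a compare-and-swap
// reservation loop. quota_try_reserve itself is not shown in this diff; a
// plausible shape, assuming a quota_used: AtomicU64 field and the two error
// variants the test matches on, is one bounded CAS attempt that reports
// Contended instead of spinning internally, leaving retry policy to callers:
use std::sync::atomic::{AtomicU64, Ordering};

enum QuotaReserveErrorDemo {
    LimitExceeded,
    Contended,
}

fn quota_try_reserve_demo(
    used: &AtomicU64,
    bytes: u64,
    limit: u64,
) -> Result<(), QuotaReserveErrorDemo> {
    let current = used.load(Ordering::Relaxed);
    let next = current.saturating_add(bytes);
    if next > limit {
        return Err(QuotaReserveErrorDemo::LimitExceeded);
    }
    // One CAS attempt; a concurrent update surfaces as Contended.
    used.compare_exchange(current, next, Ordering::AcqRel, Ordering::Relaxed)
        .map(|_| ())
        .map_err(|_| QuotaReserveErrorDemo::Contended)
}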
+    #[test]
+    fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() {
+        let user_stats = Arc::new(UserStats::default());
+        let successes = Arc::new(AtomicU64::new(0));
+        let failures = Arc::new(AtomicU64::new(0));
+        let attempts = 200usize;
+        let reserve_bytes = 1_024u64;
+        let limit = 100 * 1_024u64;
+        let mut workers = Vec::with_capacity(attempts);
+
+        for _ in 0..attempts {
+            let user_stats = user_stats.clone();
+            let successes = successes.clone();
+            let failures = failures.clone();
+            workers.push(std::thread::spawn(move || {
+                loop {
+                    match user_stats.quota_try_reserve(reserve_bytes, limit) {
+                        Ok(_) => {
+                            successes.fetch_add(1, Ordering::Relaxed);
+                            return;
+                        }
+                        Err(QuotaReserveError::LimitExceeded) => {
+                            failures.fetch_add(1, Ordering::Relaxed);
+                            return;
+                        }
+                        Err(QuotaReserveError::Contended) => {
+                            std::hint::spin_loop();
+                        }
+                    }
+                }
+            }));
+        }
+
+        for worker in workers {
+            worker.join().expect("reservation worker must finish");
+        }
+
+        assert_eq!(
+            successes.load(Ordering::Relaxed),
+            100,
+            "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota"
+        );
+        assert_eq!(
+            failures.load(Ordering::Relaxed),
+            100,
+            "remaining workers must fail once quota is fully reserved"
+        );
+        assert_eq!(user_stats.quota_used(), limit);
+    }
+
+    #[test]
+    fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() {
+        let stats = Stats::new();
+        let user = "quota-authoritative-user";
+        let user_stats = stats.get_or_create_user_stats_handle(user);
+
+        stats.add_user_octets_to_handle(&user_stats, 5);
+        assert_eq!(stats.get_user_total_octets(user), 5);
+        assert_eq!(stats.get_user_quota_used(user), 0);
+
+        stats.quota_charge_post_write(&user_stats, 7);
+        assert_eq!(stats.get_user_total_octets(user), 5);
+        assert_eq!(stats.get_user_quota_used(user), 7);
+    }
+
+    #[test]
+    fn test_cached_handle_survives_map_cleanup_until_last_drop() {
+        let stats = Stats::new();
+        let user = "quota-handle-lifetime-user";
+        let user_stats = stats.get_or_create_user_stats_handle(user);
+        let weak = Arc::downgrade(&user_stats);
+
+        stats.user_stats.remove(user);
+        assert!(
+            stats.user_stats.get(user).is_none(),
+            "map cleanup should remove idle entry"
+        );
+        assert!(
+            weak.upgrade().is_some(),
+            "cached handle must keep user stats object alive after map removal"
+        );
+
+        stats.quota_charge_post_write(user_stats.as_ref(), 3);
+        assert_eq!(user_stats.quota_used(), 3);
+
+        drop(user_stats);
+        assert!(
+            weak.upgrade().is_none(),
+            "user stats object must be dropped after the last cached handle is released"
+        );
+    }
 }
 
 #[cfg(test)]
diff --git a/src/stream/frame_stream_padding_security_tests.rs b/src/stream/frame_stream_padding_security_tests.rs
index 83b30f9..1ec787e 100644
--- a/src/stream/frame_stream_padding_security_tests.rs
+++ b/src/stream/frame_stream_padding_security_tests.rs
@@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() {
         let old = old_padding_round_up_to_4(len).expect("old expression must be safe");
         let new = new_padding_round_up_to_4(len).expect("new expression must be safe");
         assert_eq!(old, new, "mismatch for len={len}");
-        assert!(new >= len, "rounded length must not shrink: len={len}, out={new}");
+        assert!(
+            new >= len,
+            "rounded length must not shrink: len={len}, out={new}"
+        );
         assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned");
     }
 }
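// NOTE (editorial sketch): the padding test above pins the usual
// round-up-to-multiple-of-4 identity. The two compared expressions are not
// shown in this diff; the classic checked form, which satisfies both asserted
// properties (never shrinks, always 4-byte aligned), looks like this:
fn padding_round_up_to_4(len: usize) -> Option<usize> {
    // Equivalent to ((len + 3) / 4) * 4, but reports None near usize::MAX
    // instead of wrapping.
    len.checked_add(3).map(|n| n & !3)
}

#[test]
fn round_up_examples() {
    assert_eq!(padding_round_up_to_4(0), Some(0));
    assert_eq!(padding_round_up_to_4(1), Some(4));
    assert_eq!(padding_round_up_to_4(4), Some(4));
    assert_eq!(padding_round_up_to_4(usize::MAX), None);
}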
diff --git a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs
index cf42e75..3fc9727 100644
--- a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs
+++ b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs
@@ -44,7 +44,10 @@ async fn encapsulation_repeated_queue_poison_recovery_preserves_forward_progress
     let ip_primary = ip_from_idx(10_001);
     let ip_alt = ip_from_idx(10_002);
 
-    tracker.check_and_add("encap-poison", ip_primary).await.unwrap();
+    tracker
+        .check_and_add("encap-poison", ip_primary)
+        .await
+        .unwrap();
 
     for _ in 0..128 {
         let queue = tracker.cleanup_queue_mutex_for_tests();
diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs
index 4408b5a..45d56ce 100644
--- a/src/tls_front/fetcher.rs
+++ b/src/tls_front/fetcher.rs
@@ -1,7 +1,9 @@
 #![allow(clippy::too_many_arguments)]
 
+use dashmap::DashMap;
 use std::sync::Arc;
-use std::time::Duration;
+use std::sync::OnceLock;
+use std::time::{Duration, Instant};
 
 use anyhow::{Result, anyhow};
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
@@ -21,7 +23,8 @@
 use rustls::{DigitallySignedStruct, Error as RustlsError};
 use x509_parser::certificate::X509Certificate;
 use x509_parser::prelude::FromDer;
 
-use crate::crypto::SecureRandom;
+use crate::config::TlsFetchProfile;
+use crate::crypto::{SecureRandom, sha256};
 use crate::network::dns_overrides::resolve_socket_addr;
 use crate::protocol::constants::{
     TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
@@ -78,6 +81,199 @@ impl ServerCertVerifier for NoVerify {
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct TlsFetchStrategy {
+    pub profiles: Vec<TlsFetchProfile>,
+    pub strict_route: bool,
+    pub attempt_timeout: Duration,
+    pub total_budget: Duration,
+    pub grease_enabled: bool,
+    pub deterministic: bool,
+    pub profile_cache_ttl: Duration,
+}
+
+impl TlsFetchStrategy {
+    #[allow(dead_code)]
+    pub fn single_attempt(connect_timeout: Duration) -> Self {
+        Self {
+            profiles: vec![TlsFetchProfile::CompatTls12],
+            strict_route: false,
+            attempt_timeout: connect_timeout.max(Duration::from_millis(1)),
+            total_budget: connect_timeout.max(Duration::from_millis(1)),
+            grease_enabled: false,
+            deterministic: false,
+            profile_cache_ttl: Duration::ZERO,
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+struct ProfileCacheKey {
+    host: String,
+    port: u16,
+    sni: String,
+    scope: Option<String>,
+    proxy_protocol: u8,
+    route_hint: RouteHint,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum RouteHint {
+    Direct,
+    Upstream,
+    Unix,
+}
+
+#[derive(Debug, Clone, Copy)]
+struct ProfileCacheValue {
+    profile: TlsFetchProfile,
+    updated_at: Instant,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum FetchErrorKind {
+    Connect,
+    Route,
+    EarlyEof,
+    Timeout,
+    ServerHelloMissing,
+    TlsAlert,
+    Parse,
+    Other,
+}
+
+static PROFILE_CACHE: OnceLock<DashMap<ProfileCacheKey, ProfileCacheValue>> = OnceLock::new();
+
+fn profile_cache() -> &'static DashMap<ProfileCacheKey, ProfileCacheValue> {
+    PROFILE_CACHE.get_or_init(DashMap::new)
+}
+
+fn route_hint(
+    upstream: Option<&std::sync::Arc<UpstreamManager>>,
+    unix_sock: Option<&str>,
+) -> RouteHint {
+    if unix_sock.is_some() {
+        RouteHint::Unix
+    } else if upstream.is_some() {
+        RouteHint::Upstream
+    } else {
+        RouteHint::Direct
+    }
+}
+
+fn profile_cache_key(
+    host: &str,
+    port: u16,
+    sni: &str,
+    upstream: Option<&std::sync::Arc<UpstreamManager>>,
+    scope: Option<&str>,
+    proxy_protocol: u8,
+    unix_sock: Option<&str>,
+) -> ProfileCacheKey {
+    ProfileCacheKey {
+        host: host.to_string(),
+        port,
+        sni: sni.to_string(),
+        scope: scope.map(ToString::to_string),
+        proxy_protocol,
+        route_hint: route_hint(upstream, unix_sock),
+    }
+}
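// NOTE (editorial sketch): PROFILE_CACHE above is a process-wide, lazily
// initialized map remembering which ClientHello profile last worked for a
// given (host, port, sni, scope, proxy-protocol, route) tuple, so later
// fetches try the known-good profile first. The OnceLock + DashMap shape in
// isolation:
use dashmap::DashMap;
use std::sync::OnceLock;

static CACHE: OnceLock<DashMap<String, u64>> = OnceLock::new();

fn cache() -> &'static DashMap<String, u64> {
    // get_or_init is race-free: concurrent callers all observe one map.
    CACHE.get_or_init(DashMap::new)
}

fn demo() {
    cache().insert("tls.example".to_string(), 1);
    assert_eq!(cache().get("tls.example").map(|v| *v), Some(1));
}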
+fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind {
+    for cause in err.chain() {
+        if let Some(io) = cause.downcast_ref::<std::io::Error>() {
+            return match io.kind() {
+                std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout,
+                std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof,
+                std::io::ErrorKind::ConnectionRefused
+                | std::io::ErrorKind::ConnectionAborted
+                | std::io::ErrorKind::ConnectionReset
+                | std::io::ErrorKind::NotConnected
+                | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect,
+                _ => FetchErrorKind::Other,
+            };
+        }
+    }
+
+    let message = err.to_string().to_lowercase();
+    if message.contains("upstream route") {
+        FetchErrorKind::Route
+    } else if message.contains("serverhello not received") {
+        FetchErrorKind::ServerHelloMissing
+    } else if message.contains("alert") {
+        FetchErrorKind::TlsAlert
+    } else if message.contains("parse") {
+        FetchErrorKind::Parse
+    } else if message.contains("timed out") || message.contains("deadline has elapsed") {
+        FetchErrorKind::Timeout
+    } else if message.contains("eof") {
+        FetchErrorKind::EarlyEof
+    } else {
+        FetchErrorKind::Other
+    }
+}
+
+fn order_profiles(
+    strategy: &TlsFetchStrategy,
+    cache_key: Option<&ProfileCacheKey>,
+    now: Instant,
+) -> Vec<TlsFetchProfile> {
+    let mut ordered = if strategy.profiles.is_empty() {
+        vec![TlsFetchProfile::CompatTls12]
+    } else {
+        strategy.profiles.clone()
+    };
+
+    if strategy.profile_cache_ttl.is_zero() {
+        return ordered;
+    }
+
+    let Some(key) = cache_key else {
+        return ordered;
+    };
+
+    if let Some(cached) = profile_cache().get(key) {
+        let age = now.saturating_duration_since(cached.updated_at);
+        if age > strategy.profile_cache_ttl {
+            drop(cached);
+            profile_cache().remove(key);
+            return ordered;
+        }
+
+        if let Some(pos) = ordered
+            .iter()
+            .position(|profile| *profile == cached.profile)
+            && pos != 0
+        {
+            ordered.swap(0, pos);
+        }
+    }
+
+    ordered
+}
+
+fn remember_profile_success(
+    strategy: &TlsFetchStrategy,
+    cache_key: Option<ProfileCacheKey>,
+    profile: TlsFetchProfile,
+    now: Instant,
+) {
+    if strategy.profile_cache_ttl.is_zero() {
+        return;
+    }
+    let Some(key) = cache_key else {
+        return;
+    };
+    profile_cache().insert(
+        key,
+        ProfileCacheValue {
+            profile,
+            updated_at: now,
+        },
+    );
+}
+
 fn build_client_config() -> Arc<ClientConfig> {
     let root = rustls::RootCertStore::empty();
@@ -95,7 +291,114 @@ fn build_client_config() -> Arc<ClientConfig> {
     Arc::new(config)
 }
 
-fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec<u8> {
+fn deterministic_bytes(seed: &str, len: usize) -> Vec<u8> {
+    let mut out = Vec::with_capacity(len);
+    let mut counter: u32 = 0;
+    while out.len() < len {
+        let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::<u32>());
+        chunk_seed.extend_from_slice(seed.as_bytes());
+        chunk_seed.extend_from_slice(&counter.to_le_bytes());
+        out.extend_from_slice(&sha256(&chunk_seed));
+        counter = counter.wrapping_add(1);
+    }
+    out.truncate(len);
+    out
+}
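// NOTE (editorial sketch): deterministic_bytes above is a counter-mode
// expansion -- hash(seed || 0), hash(seed || 1), ... concatenated and
// truncated -- so a fixed seed string always yields the same byte stream.
// Equivalent behavior via the sha2 crate (an assumption; this codebase uses
// its own crate::crypto::sha256):
use sha2::{Digest, Sha256};

fn expand(seed: &str, len: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(len);
    let mut counter: u32 = 0;
    while out.len() < len {
        let mut h = Sha256::new();
        h.update(seed.as_bytes());
        h.update(counter.to_le_bytes());
        out.extend_from_slice(&h.finalize());
        counter = counter.wrapping_add(1);
    }
    out.truncate(len);
    out // stable for a fixed seed: expand("x", 64) == expand("x", 64)
}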
+fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] {
+    const MODERN_CHROME: &[u16] = &[
+        0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff,
+    ];
+    const MODERN_FIREFOX: &[u16] = &[
+        0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff,
+    ];
+    const COMPAT_TLS12: &[u16] = &[
+        0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff,
+    ];
+    const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff];
+
+    match profile {
+        TlsFetchProfile::ModernChromeLike => MODERN_CHROME,
+        TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX,
+        TlsFetchProfile::CompatTls12 => COMPAT_TLS12,
+        TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL,
+    }
+}
+
+fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] {
+    const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1
+    const COMPAT: &[u16] = &[0x001d, 0x0017];
+    const LEGACY: &[u16] = &[0x0017];
+
+    match profile {
+        TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN,
+        TlsFetchProfile::CompatTls12 => COMPAT,
+        TlsFetchProfile::LegacyMinimal => LEGACY,
+    }
+}
+
+fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] {
+    const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806];
+    const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805];
+    const LEGACY: &[u16] = &[0x0403, 0x0804];
+
+    match profile {
+        TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN,
+        TlsFetchProfile::CompatTls12 => COMPAT,
+        TlsFetchProfile::LegacyMinimal => LEGACY,
+    }
+}
+
+fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] {
+    const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"];
+    const HTTP11: &[&[u8]] = &[b"http/1.1"];
+    match profile {
+        TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11,
+        TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11,
+    }
+}
+
+fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] {
+    const MODERN: &[u16] = &[0x0304, 0x0303];
+    const COMPAT: &[u16] = &[0x0303, 0x0304];
+    const LEGACY: &[u16] = &[0x0303];
+    match profile {
+        TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN,
+        TlsFetchProfile::CompatTls12 => COMPAT,
+        TlsFetchProfile::LegacyMinimal => LEGACY,
+    }
+}
+
+fn profile_padding_target(profile: TlsFetchProfile) -> usize {
+    match profile {
+        TlsFetchProfile::ModernChromeLike => 220,
+        TlsFetchProfile::ModernFirefoxLike => 200,
+        TlsFetchProfile::CompatTls12 => 180,
+        TlsFetchProfile::LegacyMinimal => 64,
+    }
+}
+
+fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 {
+    const GREASE_VALUES: [u16; 16] = [
+        0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa,
+        0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa,
+    ];
+    if deterministic {
+        let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len();
+        GREASE_VALUES[idx]
+    } else {
+        let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len();
+        GREASE_VALUES[idx]
+    }
+}
+
+fn build_client_hello(
+    sni: &str,
+    rng: &SecureRandom,
+    profile: TlsFetchProfile,
+    grease_enabled: bool,
+    deterministic: bool,
+) -> Vec<u8> {
     // === ClientHello body ===
     let mut body = Vec::new();
 
@@ -103,21 +406,24 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec<u8> {
     body.extend_from_slice(&[0x03, 0x03]);
 
     // Random
-    body.extend_from_slice(&rng.bytes(32));
+    if deterministic {
+        body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32));
+    } else {
+        body.extend_from_slice(&rng.bytes(32));
+    }
 
     // Session ID: empty
     body.push(0);
 
-    // Cipher suites (common minimal set, TLS1.3 + a few 1.2 fallbacks)
-    let cipher_suites: [u8; 10] = [
-        0x13, 0x01, // TLS_AES_128_GCM_SHA256
-        0x13, 0x02, // TLS_AES_256_GCM_SHA384
-        0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256
-        0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA (legacy)
-        0x00, 0xff, // RENEGOTIATION_INFO_SCSV
-    ];
-    body.extend_from_slice(&(cipher_suites.len() as u16).to_be_bytes());
-    body.extend_from_slice(&cipher_suites);
+    let mut cipher_suites = profile_cipher_suites(profile).to_vec();
+    if grease_enabled {
+        let grease = grease_value(rng, deterministic, &format!("cipher:{sni}"));
+        cipher_suites.insert(0, grease);
+    }
+    body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes());
+    for suite in cipher_suites {
+        body.extend_from_slice(&suite.to_be_bytes());
+    }
 
     // Compression methods: null only
     body.push(1);
@@ -138,7 +444,11 @@
     exts.extend_from_slice(&sni_ext);
 
     // supported_groups
-    let groups: [u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1
+    let mut groups = profile_groups(profile).to_vec();
+    if grease_enabled {
+        let grease = grease_value(rng, deterministic, &format!("group:{sni}"));
+        groups.insert(0, grease);
+    }
     exts.extend_from_slice(&0x000au16.to_be_bytes());
     exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes());
     exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes());
@@ -147,7 +457,11 @@
     }
 
     // signature_algorithms
-    let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, rsa_pkcs1_sha256
+    let mut sig_algs = profile_sig_algs(profile).to_vec();
+    if grease_enabled {
+        let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}"));
+        sig_algs.insert(0, grease);
+    }
     exts.extend_from_slice(&0x000du16.to_be_bytes());
     exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes());
     exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes());
@@ -155,8 +469,12 @@
         exts.extend_from_slice(&a.to_be_bytes());
     }
 
-    // supported_versions (TLS1.3 + TLS1.2)
-    let versions: [u16; 2] = [0x0304, 0x0303];
+    // supported_versions
+    let mut versions = profile_supported_versions(profile).to_vec();
+    if grease_enabled {
+        let grease = grease_value(rng, deterministic, &format!("version:{sni}"));
+        versions.insert(0, grease);
+    }
     exts.extend_from_slice(&0x002bu16.to_be_bytes());
     exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes());
     exts.push((versions.len() * 2) as u8);
@@ -165,7 +483,14 @@
     }
 
     // key_share (x25519)
-    let key = gen_key_share(rng);
+    let key = if deterministic {
+        let det = deterministic_bytes(&format!("keyshare:{sni}"), 32);
+        let mut key = [0u8; 32];
+        key.copy_from_slice(&det);
+        key
+    } else {
+        gen_key_share(rng)
+    };
     let mut keyshare = Vec::with_capacity(4 + key.len());
     keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group
     keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes());
@@ -175,18 +500,29 @@
     exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes());
     exts.extend_from_slice(&keyshare);
 
-    // ALPN (http/1.1)
-    let alpn_proto = b"http/1.1";
-    exts.extend_from_slice(&0x0010u16.to_be_bytes());
-    exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes());
-    exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes());
-    exts.push(alpn_proto.len() as u8);
-    exts.extend_from_slice(alpn_proto);
+    // ALPN
+    let mut alpn_list = Vec::new();
+    for proto in profile_alpn(profile) {
+        alpn_list.push(proto.len() as u8);
+        alpn_list.extend_from_slice(proto);
+    }
+    if !alpn_list.is_empty() {
+        exts.extend_from_slice(&0x0010u16.to_be_bytes());
+        exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes());
+        exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
+        exts.extend_from_slice(&alpn_list);
+    }
+
+    if grease_enabled {
+        let grease = grease_value(rng, deterministic, &format!("ext:{sni}"));
+        exts.extend_from_slice(&grease.to_be_bytes());
+        exts.extend_from_slice(&0u16.to_be_bytes());
+    }
 
     // padding to reduce recognizability and keep length ~500 bytes
-    const TARGET_EXT_LEN: usize = 180;
-    if exts.len() < TARGET_EXT_LEN {
-        let remaining = TARGET_EXT_LEN - exts.len();
+    let target_ext_len = profile_padding_target(profile);
+    if exts.len() < target_ext_len {
+        let remaining = target_ext_len - exts.len();
         if remaining > 4 {
             let pad_len = remaining - 4; // minus type+len
             exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension
@@ -402,27 +738,41 @@ async fn connect_tcp_with_upstream(
     connect_timeout: Duration,
     upstream: Option<Arc<UpstreamManager>>,
     scope: Option<&str>,
+    strict_route: bool,
 ) -> Result<UpstreamStream> {
     if let Some(manager) = upstream {
-        if let Some(addr) = resolve_socket_addr(host, port) {
-            match manager.connect(addr, None, scope).await {
-                Ok(stream) => return Ok(stream),
-                Err(e) => {
-                    warn!(
-                        host = %host,
-                        port = port,
-                        scope = ?scope,
-                        error = %e,
-                        "Upstream connect failed, using direct connect"
-                    );
-                }
-            }
-        } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await
-            && let Some(addr) = addrs.find(|a| a.is_ipv4())
-        {
+        let resolved = if let Some(addr) = resolve_socket_addr(host, port) {
+            Some(addr)
+        } else {
+            match tokio::net::lookup_host((host, port)).await {
+                Ok(mut addrs) => addrs.find(|a| a.is_ipv4()),
+                Err(e) => {
+                    if strict_route {
+                        return Err(anyhow!(
+                            "upstream route DNS resolution failed for {host}:{port}: {e}"
+                        ));
+                    }
+                    warn!(
+                        host = %host,
+                        port = port,
+                        scope = ?scope,
+                        error = %e,
+                        "Upstream DNS resolution failed, using direct connect"
+                    );
+                    None
+                }
+            }
+        };
+
+        if let Some(addr) = resolved {
             match manager.connect(addr, None, scope).await {
                 Ok(stream) => return Ok(stream),
                 Err(e) => {
+                    if strict_route {
+                        return Err(anyhow!(
+                            "upstream route connect failed for {host}:{port}: {e}"
+                        ));
+                    }
                     warn!(
                         host = %host,
                         port = port,
@@ -432,6 +782,10 @@ async fn connect_tcp_with_upstream(
                     );
                 }
             }
+        } else if strict_route {
+            return Err(anyhow!(
+                "upstream route resolution produced no usable address for {host}:{port}"
+            ));
         }
     }
     Ok(UpstreamStream::Tcp(
@@ -471,12 +825,15 @@ async fn fetch_via_raw_tls_stream(
     sni: &str,
     connect_timeout: Duration,
     proxy_protocol: u8,
+    profile: TlsFetchProfile,
+    grease_enabled: bool,
+    deterministic: bool,
 ) -> Result
 where
     S: AsyncRead + AsyncWrite + Unpin,
 {
     let rng = SecureRandom::new();
-    let client_hello = build_client_hello(sni, &rng);
+    let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic);
     timeout(connect_timeout, async {
         if proxy_protocol > 0 {
             let header = match proxy_protocol {
@@ -550,6 +907,10 @@ async fn fetch_via_raw_tls(
     scope: Option<&str>,
     proxy_protocol: u8,
     unix_sock: Option<&str>,
+    strict_route: bool,
+    profile: TlsFetchProfile,
+    grease_enabled: bool,
+    deterministic: bool,
 ) -> Result {
     #[cfg(unix)]
     if let Some(sock_path) = unix_sock {
@@ -560,8 +921,16 @@ async fn fetch_via_raw_tls(
                     sock = %sock_path,
                     "Raw TLS fetch using mask unix socket"
                 );
-                return fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol)
-                    .await;
+                return fetch_via_raw_tls_stream(
+                    stream,
+                    sni,
+                    connect_timeout,
+                    proxy_protocol,
+                    profile,
+                    grease_enabled,
+                    deterministic,
+                )
+                .await;
             }
             Ok(Err(e)) => {
                 warn!(
@@ -584,8 +953,19 @@ async fn fetch_via_raw_tls(
     #[cfg(not(unix))]
     let _ = unix_sock;
 
-    let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?;
-    fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await
+    let stream =
+        connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
+            .await?;
+    fetch_via_raw_tls_stream(
+        stream,
+        sni,
+        connect_timeout,
+        proxy_protocol,
+        profile,
+        grease_enabled,
+        deterministic,
+    )
+    .await
 }
 
 async fn fetch_via_rustls_stream(
@@ -691,6 +1071,7 @@ async fn fetch_via_rustls(
     scope: Option<&str>,
     proxy_protocol: u8,
     unix_sock: Option<&str>,
+    strict_route: bool,
 ) -> Result {
     #[cfg(unix)]
     if let Some(sock_path) = unix_sock {
@@ -724,16 +1105,153 @@ async fn fetch_via_rustls(
     #[cfg(not(unix))]
     let _ = unix_sock;
 
-    let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?;
+    let stream =
+        connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
+            .await?;
     fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await
 }
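// NOTE (editorial sketch): fetch_real_tls_with_strategy below enforces two
// separate limits -- a per-attempt timeout and a total wall-clock budget -- by
// re-deriving each attempt's timeout as min(attempt_timeout, budget - elapsed).
// That control flow in isolation:
use std::time::{Duration, Instant};

fn next_attempt_timeout(
    started_at: Instant,
    attempt_timeout: Duration,
    total_budget: Duration,
) -> Option<Duration> {
    let elapsed = started_at.elapsed();
    if elapsed >= total_budget {
        return None; // budget exhausted: stop trying profiles
    }
    Some(attempt_timeout.min(total_budget - elapsed))
}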
exhausted"))) + } + }; + } + + let rustls_timeout = attempt_timeout.min(total_budget - elapsed); + let rustls_result = fetch_via_rustls( + host, + port, + sni, + rustls_timeout, + upstream, + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + ) + .await; + + match rustls_result { + Ok(rustls) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls.cert_info; + raw.cert_payload = rustls.cert_payload; + raw.behavior_profile.source = TlsProfileSource::Merged; + debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls) + } + } + Err(err) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else if let Some(raw_err) = raw_last_error { + Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) + } else { + Err(err) + } + } + } +} + +/// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. +#[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, @@ -744,62 +1262,30 @@ pub async fn fetch_real_tls( proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls( + let strategy = TlsFetchStrategy::single_attempt(connect_timeout); + fetch_real_tls_with_strategy( host, port, sni, - connect_timeout, - upstream.clone(), - scope, - proxy_protocol, - unix_sock, - ) - .await - { - Ok(res) => Some(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); - None - } - }; - - match fetch_via_rustls( - host, - port, - sni, - connect_timeout, + &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await - { - Ok(rustls_result) => { - if let Some(mut raw) = raw_result { - raw.cert_info = rustls_result.cert_info; - raw.cert_payload = rustls_result.cert_payload; - raw.behavior_profile.source = TlsProfileSource::Merged; - debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); - Ok(raw) - } else { - Ok(rustls_result) - } - } - Err(e) => { - if let Some(raw) = raw_result { - warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); - Ok(raw) - } else { - Err(e) - } - } - } } #[cfg(test)] mod tests { - use super::{derive_behavior_profile, encode_tls13_certificate_message}; + use std::time::{Duration, Instant}; + + use super::{ + ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, + encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + }; + use crate::config::TlsFetchProfile; + use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; @@ -812,8 +1298,8 @@ mod tests { #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; - let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)) - .expect("message"); + let message = + encode_tls13_certificate_message(std::slice::from_ref(&cert)).expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); @@ -848,4 +1334,93 @@ mod tests { assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } + + #[test] + fn test_order_profiles_prioritizes_fresh_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + 
+
+    #[test]
+    fn test_order_profiles_prioritizes_fresh_cached_winner() {
+        let strategy = TlsFetchStrategy {
+            profiles: vec![
+                TlsFetchProfile::ModernChromeLike,
+                TlsFetchProfile::CompatTls12,
+                TlsFetchProfile::LegacyMinimal,
+            ],
+            strict_route: true,
+            attempt_timeout: Duration::from_secs(1),
+            total_budget: Duration::from_secs(2),
+            grease_enabled: false,
+            deterministic: false,
+            profile_cache_ttl: Duration::from_secs(60),
+        };
+        let cache_key = profile_cache_key(
+            "mask.example",
+            443,
+            "tls.example",
+            None,
+            Some("tls"),
+            0,
+            None,
+        );
+        profile_cache().remove(&cache_key);
+        profile_cache().insert(
+            cache_key.clone(),
+            ProfileCacheValue {
+                profile: TlsFetchProfile::CompatTls12,
+                updated_at: Instant::now(),
+            },
+        );
+
+        let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now());
+        assert_eq!(ordered[0], TlsFetchProfile::CompatTls12);
+        profile_cache().remove(&cache_key);
+    }
+
+    #[test]
+    fn test_order_profiles_drops_expired_cached_winner() {
+        let strategy = TlsFetchStrategy {
+            profiles: vec![
+                TlsFetchProfile::ModernFirefoxLike,
+                TlsFetchProfile::CompatTls12,
+            ],
+            strict_route: true,
+            attempt_timeout: Duration::from_secs(1),
+            total_budget: Duration::from_secs(2),
+            grease_enabled: false,
+            deterministic: false,
+            profile_cache_ttl: Duration::from_secs(5),
+        };
+        let cache_key =
+            profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None);
+        profile_cache().remove(&cache_key);
+        profile_cache().insert(
+            cache_key.clone(),
+            ProfileCacheValue {
+                profile: TlsFetchProfile::CompatTls12,
+                updated_at: Instant::now() - Duration::from_secs(6),
+            },
+        );
+
+        let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now());
+        assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike);
+        assert!(profile_cache().get(&cache_key).is_none());
+    }
+
+    #[test]
+    fn test_deterministic_client_hello_is_stable() {
+        let rng = SecureRandom::new();
+        let first = build_client_hello(
+            "stable.example",
+            &rng,
+            TlsFetchProfile::ModernChromeLike,
+            true,
+            true,
+        );
+        let second = build_client_hello(
+            "stable.example",
+            &rng,
+            TlsFetchProfile::ModernChromeLike,
+            true,
+            true,
+        );
+
+        assert_eq!(first, second);
+    }
 }
diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs
index 8e5a701..ebe45fc 100644
--- a/src/transport/middle_proxy/config_updater.rs
+++ b/src/transport/middle_proxy/config_updater.rs
@@ -11,17 +11,19 @@ use tracing::{debug, info, warn};
 
 use crate::config::ProxyConfig;
 use crate::error::Result;
+use crate::transport::UpstreamManager;
 
 use super::MePool;
+use super::http_fetch::https_get;
 use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger};
-use super::secret::download_proxy_secret_with_max_len;
+use super::secret::download_proxy_secret_with_max_len_via_upstream;
 use super::selftest::record_timeskew_sample;
 
 use std::time::SystemTime;
 
-async fn retry_fetch(url: &str) -> Option<ProxyConfigData> {
+async fn retry_fetch(url: &str, upstream: Option<Arc<UpstreamManager>>) -> Option<ProxyConfigData> {
     let delays = [1u64, 5, 15];
     for (i, d) in delays.iter().enumerate() {
-        match fetch_proxy_config(url).await {
+        match fetch_proxy_config_via_upstream(url, upstream.clone()).await {
             Ok(cfg) => return Some(cfg),
             Err(e) => {
                 if i == delays.len() - 1 {
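// NOTE (editorial sketch): retry_fetch above walks a fixed 1s/5s/15s delay
// ladder and gives up after the last rung. The same shape in isolation, with
// the attempt abstracted behind a closure (names here are illustrative):
async fn retry_with_ladder<T, F, Fut>(mut attempt: F) -> Option<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, String>>,
{
    let delays = [1u64, 5, 15];
    for (i, d) in delays.iter().enumerate() {
        match attempt().await {
            Ok(v) => return Some(v),
            // Back off before the next rung; the last failure just gives up.
            Err(_) if i + 1 < delays.len() => {
                tokio::time::sleep(std::time::Duration::from_secs(*d)).await;
            }
            Err(_) => return None,
        }
    }
    None
}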
@@ -95,14 +97,19 @@ pub async fn save_proxy_config_cache(path: &str, raw_text: &str) -> Result<()> {
     Ok(())
 }
 
+#[allow(dead_code)]
 pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, String)> {
-    let resp = reqwest::get(url).await.map_err(|e| {
-        crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}"))
-    })?;
-    let http_status = resp.status().as_u16();
+    fetch_proxy_config_with_raw_via_upstream(url, None).await
+}
 
-    if let Some(date) = resp.headers().get(reqwest::header::DATE)
-        && let Ok(date_str) = date.to_str()
+pub async fn fetch_proxy_config_with_raw_via_upstream(
+    url: &str,
+    upstream: Option<Arc<UpstreamManager>>,
+) -> Result<(ProxyConfigData, String)> {
+    let resp = https_get(url, upstream).await?;
+    let http_status = resp.status;
+
+    if let Some(date_str) = resp.date_header.as_deref()
         && let Ok(server_time) = httpdate::parse_http_date(date_str)
         && let Ok(skew) = SystemTime::now()
             .duration_since(server_time)
@@ -123,9 +130,7 @@ pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData,
         }
     }
 
-    let text = resp.text().await.map_err(|e| {
-        crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}"))
-    })?;
+    let text = String::from_utf8_lossy(&resp.body).into_owned();
     let parsed = parse_proxy_config_text(&text, http_status);
     Ok((parsed, text))
 }
@@ -260,8 +265,16 @@ fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> {
     Some((dc, ip, port))
 }
 
+#[allow(dead_code)]
 pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> {
-    fetch_proxy_config_with_raw(url)
+    fetch_proxy_config_via_upstream(url, None).await
+}
+
+pub async fn fetch_proxy_config_via_upstream(
+    url: &str,
+    upstream: Option<Arc<UpstreamManager>>,
+) -> Result<ProxyConfigData> {
+    fetch_proxy_config_with_raw_via_upstream(url, upstream)
         .await
         .map(|(parsed, _raw)| parsed)
 }
@@ -300,53 +313,7 @@ async fn run_update_cycle(
     state: &mut UpdaterState,
     reinit_tx: &mpsc::Sender<MeReinitTrigger>,
 ) {
-    pool.update_runtime_reinit_policy(
-        cfg.general.hardswap,
-        cfg.general.me_pool_drain_ttl_secs,
-        cfg.general.me_instadrain,
-        cfg.general.me_pool_drain_threshold,
-        cfg.general.me_pool_drain_soft_evict_enabled,
-        cfg.general.me_pool_drain_soft_evict_grace_secs,
-        cfg.general.me_pool_drain_soft_evict_per_writer,
-        cfg.general.me_pool_drain_soft_evict_budget_per_core,
-        cfg.general.me_pool_drain_soft_evict_cooldown_ms,
-        cfg.general.effective_me_pool_force_close_secs(),
-        cfg.general.me_pool_min_fresh_ratio,
-        cfg.general.me_hardswap_warmup_delay_min_ms,
-        cfg.general.me_hardswap_warmup_delay_max_ms,
-        cfg.general.me_hardswap_warmup_extra_passes,
-        cfg.general.me_hardswap_warmup_pass_backoff_base_ms,
-        cfg.general.me_bind_stale_mode,
-        cfg.general.me_bind_stale_ttl_secs,
-        cfg.general.me_secret_atomic_snapshot,
-        cfg.general.me_deterministic_writer_sort,
-        cfg.general.me_writer_pick_mode,
-        cfg.general.me_writer_pick_sample_size,
-        cfg.general.me_single_endpoint_shadow_writers,
-        cfg.general.me_single_endpoint_outage_mode_enabled,
-        cfg.general.me_single_endpoint_outage_disable_quarantine,
-        cfg.general.me_single_endpoint_outage_backoff_min_ms,
-        cfg.general.me_single_endpoint_outage_backoff_max_ms,
-        cfg.general.me_single_endpoint_shadow_rotate_every_secs,
-        cfg.general.me_floor_mode,
-        cfg.general.me_adaptive_floor_idle_secs,
-        cfg.general.me_adaptive_floor_min_writers_single_endpoint,
-        cfg.general.me_adaptive_floor_min_writers_multi_endpoint,
-        cfg.general.me_adaptive_floor_recover_grace_secs,
-        cfg.general.me_adaptive_floor_writers_per_core_total,
-        cfg.general.me_adaptive_floor_cpu_cores_override,
-        cfg.general
-            .me_adaptive_floor_max_extra_writers_single_per_core,
-        cfg.general
-            .me_adaptive_floor_max_extra_writers_multi_per_core,
-        cfg.general.me_adaptive_floor_max_active_writers_per_core,
-        cfg.general.me_adaptive_floor_max_warm_writers_per_core,
-        cfg.general.me_adaptive_floor_max_active_writers_global,
-        cfg.general.me_adaptive_floor_max_warm_writers_global,
-        cfg.general.me_health_interval_ms_unhealthy,
-        cfg.general.me_health_interval_ms_healthy,
-        cfg.general.me_warn_rate_limit_ms,
-    );
+    let upstream = pool.upstream.clone();
 
     let required_cfg_snapshots = cfg.general.me_config_stable_snapshots.max(1);
     let required_secret_snapshots = cfg.general.proxy_secret_stable_snapshots.max(1);
@@ -354,7 +321,7 @@ async fn run_update_cycle(
     let mut maps_changed = false;
 
     let mut ready_v4: Option<(ProxyConfigData, u64)> = None;
-    let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await;
+    let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig", upstream.clone()).await;
     if let Some(cfg_v4) = cfg_v4
         && snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig")
     {
@@ -378,7 +345,11 @@ async fn run_update_cycle(
     }
 
     let mut ready_v6: Option<(ProxyConfigData, u64)> = None;
-    let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await;
+    let cfg_v6 = retry_fetch(
+        "https://core.telegram.org/getProxyConfigV6",
+        upstream.clone(),
+    )
+    .await;
     if let Some(cfg_v6) = cfg_v6
         && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6")
     {
@@ -456,7 +427,12 @@ async fn run_update_cycle(
         pool.reset_stun_state();
 
     if cfg.general.proxy_secret_rotate_runtime {
-        match download_proxy_secret_with_max_len(cfg.general.proxy_secret_len_max).await {
+        match download_proxy_secret_with_max_len_via_upstream(
+            cfg.general.proxy_secret_len_max,
+            upstream,
+        )
+        .await
+        {
             Ok(secret) => {
                 let secret_hash = hash_secret(&secret);
                 let stable_hits = state.secret.observe(secret_hash);
diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs
index b6eff37..01206e2 100644
--- a/src/transport/middle_proxy/handshake.rs
+++ b/src/transport/middle_proxy/handshake.rs
@@ -161,7 +161,7 @@ impl MePool {
         } else {
             let connect_fut = async {
                 if addr.is_ipv6()
-                    && let Some(v6) = self.detected_ipv6
+                    && let Some(v6) = self.nat_runtime.detected_ipv6
                 {
                     match TcpSocket::new_v6() {
                         Ok(sock) => {
@@ -305,7 +305,7 @@ impl MePool {
             }
             MeSocksKdfPolicy::Compat => {
                 self.stats.increment_me_socks_kdf_compat_fallback();
-                if self.nat_probe {
+                if self.nat_runtime.nat_probe {
                     let bind_ip = Self::direct_bind_ip_for_stun(family, upstream_egress);
                     self.maybe_reflect_public_addr(family, bind_ip).await
                 } else {
@@ -313,7 +313,7 @@ impl MePool {
                 }
             }
         }
-    } else if self.nat_probe {
+    } else if self.nat_runtime.nat_probe {
         let bind_ip = Self::direct_bind_ip_for_stun(family, upstream_egress);
         self.maybe_reflect_public_addr(family, bind_ip).await
     } else {
@@ -343,7 +343,10 @@ impl MePool {
         .unwrap_or_default()
         .as_secs() as u32;
 
-    let secret_atomic_snapshot = self.secret_atomic_snapshot.load(Ordering::Relaxed);
+    let secret_atomic_snapshot = self
+        .writer_selection_policy
+        .secret_atomic_snapshot
+        .load(Ordering::Relaxed);
     let (ks, secret) = if secret_atomic_snapshot {
         let snapshot = self.secret_snapshot().await;
         (snapshot.key_selector, snapshot.secret)
diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs
index 3e53f38..257d8f3 100644
--- a/src/transport/middle_proxy/health.rs
+++ b/src/transport/middle_proxy/health.rs
@@ -7,6 +7,8 @@ use std::sync::Arc;
 use std::time::{Duration, Instant};
 
 use rand::RngExt;
+use tokio::sync::Semaphore;
+use tokio::task::JoinSet;
 use tracing::{debug, info, warn};
 
 use crate::config::MeFloorMode;
@@ -14,6 +16,7 @@ use crate::crypto::SecureRandom;
 use crate::network::IpFamily;
 
 use super::MePool;
+use super::pool::MeFamilyRuntimeState;
 
 const JITTER_FRAC_NUM: u64 = 2; // jitter up to 50% of backoff
 #[allow(dead_code)]
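// NOTE (editorial sketch): JITTER_FRAC_NUM = 2 above means reconnect backoff
// gets up to 1/2 of its base value added as random jitter (the "up to 50%"
// comment). The arithmetic in isolation, assuming a uniform random source;
// the pool's actual RNG plumbing differs:
fn jittered_backoff_ms(base_ms: u64, rand_u64: u64) -> u64 {
    const JITTER_FRAC_NUM: u64 = 2;
    let jitter_span = base_ms / JITTER_FRAC_NUM; // up to 50% of the backoff
    let jitter = if jitter_span == 0 { 0 } else { rand_u64 % jitter_span };
    base_ms.saturating_add(jitter)
}
// e.g. base 1000 ms yields a wait in [1000, 1500).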
@@ -27,6 +30,9 @@ const HEALTH_RECONNECT_BUDGET_PER_CORE: usize = 2;
 const HEALTH_RECONNECT_BUDGET_PER_DC: usize = 1;
 const HEALTH_RECONNECT_BUDGET_MIN: usize = 4;
 const HEALTH_RECONNECT_BUDGET_MAX: usize = 128;
+const FAMILY_SUPPRESS_FAIL_STREAK_THRESHOLD: u32 = 5;
+const FAMILY_SUPPRESS_DURATION_SECS: u64 = 60;
+const FAMILY_RECOVER_SUCCESS_STREAK_TARGET: u32 = 2;
 const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16;
 const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16;
 const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256;
@@ -56,6 +62,17 @@ struct FamilyFloorPlan {
     target_writers_total: usize,
 }
 
+#[derive(Debug)]
+struct FamilyReconnectOutcome {
+    key: (i32, IpFamily),
+    dc: i32,
+    family: IpFamily,
+    alive: usize,
+    required: usize,
+    endpoint_count: usize,
+    restored: usize,
+}
+
 pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
     let mut backoff: HashMap<(i32, IpFamily), u64> = HashMap::new();
     let mut next_attempt: HashMap<(i32, IpFamily), Instant> = HashMap::new();
@@ -78,6 +95,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
         };
         tokio::time::sleep(interval).await;
         pool.prune_closed_writers().await;
+        pool.sweep_endpoint_quarantine().await;
         reap_draining_writers(&pool, &mut drain_warn_next_allowed).await;
         let v4_degraded = check_family(
             IpFamily::V4,
@@ -113,6 +131,8 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
             &mut floor_warn_next_allowed,
         )
         .await;
+        update_family_runtime_state(&pool, IpFamily::V4, v4_degraded);
+        update_family_runtime_state(&pool, IpFamily::V6, v6_degraded);
         degraded_interval = v4_degraded || v6_degraded;
     }
 }
@@ -135,9 +155,11 @@ pub(super) async fn reap_draining_writers(
     let now_epoch_secs = MePool::now_epoch_secs();
     let now = Instant::now();
     let drain_ttl_secs = pool
+        .drain_runtime
         .me_pool_drain_ttl_secs
         .load(std::sync::atomic::Ordering::Relaxed);
     let drain_threshold = pool
+        .drain_runtime
         .me_pool_drain_threshold
         .load(std::sync::atomic::Ordering::Relaxed);
     let activity = pool.registry.writer_activity_snapshot().await;
@@ -221,7 +243,10 @@ pub(super) async fn reap_draining_writers(
                 endpoint = %writer.addr,
                 generation = writer.generation,
                 drain_ttl_secs,
-                force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
+                force_close_secs = pool
+                    .drain_runtime
+                    .me_pool_force_close_secs
+                    .load(std::sync::atomic::Ordering::Relaxed),
                 allow_drain_fallback = writer.allow_drain_fallback,
                 "ME draining writer remains non-empty past drain TTL"
             );
@@ -365,7 +390,8 @@ async fn check_family(
         endpoints.sort_unstable();
         endpoints.dedup();
     }
-    let mut reconnect_budget = health_reconnect_budget(pool, dc_endpoints.len());
+    let reconnect_budget = health_reconnect_budget(pool, dc_endpoints.len());
+    let reconnect_sem = Arc::new(Semaphore::new(reconnect_budget));
 
     if pool.floor_mode() == MeFloorMode::Static {
         adaptive_idle_since.clear();
@@ -422,6 +448,10 @@ async fn check_family(
         floor_plan.active_writers_current,
         floor_plan.warm_writers_current,
     );
+    let live_writer_ids_by_addr = Arc::new(live_writer_ids_by_addr);
+    let writer_idle_since = Arc::new(writer_idle_since);
+    let bound_clients_by_writer = Arc::new(bound_clients_by_writer);
+    let mut reconnect_set = JoinSet::<FamilyReconnectOutcome>::new();
 
     for (dc, endpoints) in dc_endpoints {
         if endpoints.is_empty() {
@@ -461,7 +491,7 @@ async fn check_family(
                 required,
                 outage_backoff,
                 outage_next_attempt,
-                &mut reconnect_budget,
+                &reconnect_sem,
             )
             .await;
             continue;
         }
@@ -495,9 +525,9 @@ async fn check_family(
             &endpoints,
             alive,
             required,
-            &live_writer_ids_by_addr,
-            &writer_idle_since,
-            &bound_clients_by_writer,
live_writer_ids_by_addr.as_ref(), + writer_idle_since.as_ref(), + bound_clients_by_writer.as_ref(), idle_refresh_next_attempt, ) .await; @@ -510,8 +540,8 @@ async fn check_family( &endpoints, alive, required, - &live_writer_ids_by_addr, - &bound_clients_by_writer, + live_writer_ids_by_addr.as_ref(), + bound_clients_by_writer.as_ref(), shadow_rotate_deadline, ) .await; @@ -521,8 +551,8 @@ async fn check_family( family_degraded = true; let now = Instant::now(); - if reconnect_budget == 0 { - let base_ms = pool.me_reconnect_backoff_base.as_millis() as u64; + if reconnect_sem.available_permits() == 0 { + let base_ms = pool.reconnect_runtime.me_reconnect_backoff_base.as_millis() as u64; let next_ms = (*backoff.get(&key).unwrap_or(&base_ms)).max(base_ms); let jitter = next_ms / JITTER_FRAC_NUM; let wait = Duration::from_millis(next_ms) @@ -545,7 +575,10 @@ async fn check_family( continue; } - let max_concurrent = pool.me_reconnect_max_concurrent_per_dc.max(1) as usize; + let max_concurrent = pool + .reconnect_runtime + .me_reconnect_max_concurrent_per_dc + .max(1) as usize; if *inflight.get(&key).unwrap_or(&0) >= max_concurrent { continue; } @@ -564,117 +597,165 @@ async fn check_family( continue; } *inflight.entry(key).or_insert(0) += 1; - - let mut restored = 0usize; - for _ in 0..missing { - if reconnect_budget == 0 { - break; - } - reconnect_budget = reconnect_budget.saturating_sub(1); - if pool.active_contour_writer_count_total().await - >= floor_plan.active_cap_effective_total - { - let swapped = maybe_swap_idle_writer_for_cap( - pool, - rng, - dc, - family, - &endpoints, - &live_writer_ids_by_addr, - &writer_idle_since, - &bound_clients_by_writer, + let pool_for_reconnect = pool.clone(); + let rng_for_reconnect = rng.clone(); + let reconnect_sem_for_dc = reconnect_sem.clone(); + let endpoints_for_dc = endpoints.clone(); + let live_writer_ids_by_addr_for_dc = live_writer_ids_by_addr.clone(); + let writer_idle_since_for_dc = writer_idle_since.clone(); + let bound_clients_by_writer_for_dc = bound_clients_by_writer.clone(); + let active_cap_effective_total = floor_plan.active_cap_effective_total; + reconnect_set.spawn(async move { + let mut restored = 0usize; + for _ in 0..missing { + let Ok(reconnect_permit) = reconnect_sem_for_dc.clone().try_acquire_owned() else { + break; + }; + if pool_for_reconnect.active_contour_writer_count_total().await + >= active_cap_effective_total + { + let swapped = maybe_swap_idle_writer_for_cap( + &pool_for_reconnect, + &rng_for_reconnect, + dc, + family, + &endpoints_for_dc, + live_writer_ids_by_addr_for_dc.as_ref(), + writer_idle_since_for_dc.as_ref(), + bound_clients_by_writer_for_dc.as_ref(), + ) + .await; + if swapped { + pool_for_reconnect + .stats + .increment_me_floor_swap_idle_total(); + restored += 1; + continue; + } + pool_for_reconnect + .stats + .increment_me_floor_cap_block_total(); + pool_for_reconnect + .stats + .increment_me_floor_swap_idle_failed_total(); + debug!( + dc = %dc, + ?family, + alive, + required, + active_cap_effective_total, + "Adaptive floor cap reached, reconnect attempt blocked" + ); + break; + } + let res = tokio::time::timeout( + pool_for_reconnect.reconnect_runtime.me_one_timeout, + pool_for_reconnect.connect_endpoints_round_robin( + dc, + &endpoints_for_dc, + rng_for_reconnect.as_ref(), + ), ) .await; - if swapped { - pool.stats.increment_me_floor_swap_idle_total(); - restored += 1; - continue; + match res { + Ok(true) => { + restored += 1; + pool_for_reconnect.stats.increment_me_reconnect_success(); + } + Ok(false) => { + 
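
The reconnect fan-out above follows the standard `JoinSet` fan-out/fan-in shape: each per-DC task is spawned with cloned `Arc` handles and returns a small outcome struct, which the caller drains with `join_next()`. A self-contained sketch; `Outcome` and its fields are illustrative stand-ins for `FamilyReconnectOutcome`:

```rust
use tokio::task::JoinSet;

#[derive(Debug)]
struct Outcome {
    dc: i32,
    restored: usize,
}

#[tokio::main]
async fn main() {
    let mut set = JoinSet::new();
    for dc in [1, 2, 4] {
        set.spawn(async move {
            // ... per-DC reconnect work would happen here ...
            Outcome { dc, restored: 1 }
        });
    }

    // join_next() yields results in completion order, not spawn order;
    // a JoinError (panic/cancellation) is logged and skipped, as above.
    while let Some(joined) = set.join_next().await {
        match joined {
            Ok(outcome) => println!("dc {} restored {}", outcome.dc, outcome.restored),
            Err(join_error) => eprintln!("reconnect task failed: {join_error}"),
        }
    }
}
```
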
pool_for_reconnect.stats.increment_me_reconnect_attempt(); + debug!(dc = %dc, ?family, "ME round-robin reconnect failed") + } + Err(_) => { + pool_for_reconnect.stats.increment_me_reconnect_attempt(); + debug!(dc = %dc, ?family, "ME reconnect timed out"); + } } - pool.stats.increment_me_floor_cap_block_total(); - pool.stats.increment_me_floor_swap_idle_failed_total(); - debug!( - dc = %dc, - ?family, - alive, - required, - active_cap_effective_total = floor_plan.active_cap_effective_total, - "Adaptive floor cap reached, reconnect attempt blocked" - ); - break; + drop(reconnect_permit); } - let res = tokio::time::timeout( - pool.me_one_timeout, - pool.connect_endpoints_round_robin(dc, &endpoints, rng.as_ref()), - ) - .await; - match res { - Ok(true) => { - restored += 1; - pool.stats.increment_me_reconnect_success(); - } - Ok(false) => { - pool.stats.increment_me_reconnect_attempt(); - debug!(dc = %dc, ?family, "ME round-robin reconnect failed") - } - Err(_) => { - pool.stats.increment_me_reconnect_attempt(); - debug!(dc = %dc, ?family, "ME reconnect timed out"); - } - } - } - let now_alive = alive + restored; - if now_alive >= required { - info!( - dc = %dc, - ?family, - alive = now_alive, + FamilyReconnectOutcome { + key, + dc, + family, + alive, required, - endpoint_count = endpoints.len(), + endpoint_count: endpoints_for_dc.len(), + restored, + } + }); + } + + while let Some(joined) = reconnect_set.join_next().await { + let outcome = match joined { + Ok(outcome) => outcome, + Err(join_error) => { + debug!(error = %join_error, "Health reconnect task failed"); + continue; + } + }; + let now = Instant::now(); + let now_alive = outcome.alive + outcome.restored; + if now_alive >= outcome.required { + info!( + dc = %outcome.dc, + family = ?outcome.family, + alive = now_alive, + required = outcome.required, + endpoint_count = outcome.endpoint_count, "ME writer floor restored for DC" ); - backoff.insert(key, pool.me_reconnect_backoff_base.as_millis() as u64); - let jitter = pool.me_reconnect_backoff_base.as_millis() as u64 / JITTER_FRAC_NUM; - let wait = pool.me_reconnect_backoff_base + backoff.insert( + outcome.key, + pool.reconnect_runtime.me_reconnect_backoff_base.as_millis() as u64, + ); + let jitter = pool.reconnect_runtime.me_reconnect_backoff_base.as_millis() as u64 + / JITTER_FRAC_NUM; + let wait = pool.reconnect_runtime.me_reconnect_backoff_base + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1))); - next_attempt.insert(key, now + wait); + next_attempt.insert(outcome.key, now + wait); } else { let curr = *backoff - .get(&key) - .unwrap_or(&(pool.me_reconnect_backoff_base.as_millis() as u64)); - let next_ms = - (curr.saturating_mul(2)).min(pool.me_reconnect_backoff_cap.as_millis() as u64); - backoff.insert(key, next_ms); + .get(&outcome.key) + .unwrap_or(&(pool.reconnect_runtime.me_reconnect_backoff_base.as_millis() as u64)); + let next_ms = (curr.saturating_mul(2)) + .min(pool.reconnect_runtime.me_reconnect_backoff_cap.as_millis() as u64); + backoff.insert(outcome.key, next_ms); let jitter = next_ms / JITTER_FRAC_NUM; let wait = Duration::from_millis(next_ms) + Duration::from_millis(rand::rng().random_range(0..=jitter.max(1))); - next_attempt.insert(key, now + wait); + next_attempt.insert(outcome.key, now + wait); if pool.is_runtime_ready() { let warn_cooldown = pool.warn_rate_limit_duration(); - if should_emit_rate_limited_warn(floor_warn_next_allowed, key, now, warn_cooldown) { + if should_emit_rate_limited_warn( + floor_warn_next_allowed, + outcome.key, + now, + 
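
The backoff arithmetic threaded through this block is: double the stored delay up to `me_reconnect_backoff_cap`, then sleep that long plus random jitter of up to `1/JITTER_FRAC_NUM` of it (half, i.e. up to 50%). A deterministic sketch of just the math; the real code draws the jitter with `rand::rng().random_range(..)`, here a fixed roll stands in:

```rust
use std::time::Duration;

const JITTER_FRAC: u64 = 2; // jitter up to 50% of the backoff, as above

/// Returns (next stored backoff in ms, actual wait including jitter).
fn schedule(curr_ms: u64, cap_ms: u64, jitter_roll: u64) -> (u64, Duration) {
    let next_ms = curr_ms.saturating_mul(2).min(cap_ms);
    let jitter_span = (next_ms / JITTER_FRAC).max(1);
    // jitter_roll stands in for rand::rng().random_range(0..=jitter_span)
    let wait = Duration::from_millis(next_ms + jitter_roll % (jitter_span + 1));
    (next_ms, wait)
}

fn main() {
    let (mut curr, cap) = (250u64, 30_000u64);
    for attempt in 0..5 {
        let (next, wait) = schedule(curr, cap, 123);
        println!("attempt {attempt}: backoff {next} ms, sleep {wait:?}");
        curr = next;
    }
}
```

On success the stored value is reset to `me_reconnect_backoff_base` instead of being doubled, which is why the restored branch above re-inserts the base delay.
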
warn_cooldown, + ) { warn!( - dc = %dc, - ?family, + dc = %outcome.dc, + family = ?outcome.family, alive = now_alive, - required, - endpoint_count = endpoints.len(), + required = outcome.required, + endpoint_count = outcome.endpoint_count, backoff_ms = next_ms, "DC writer floor is below required level, scheduled reconnect" ); } } else { info!( - dc = %dc, - ?family, + dc = %outcome.dc, + family = ?outcome.family, alive = now_alive, - required, - endpoint_count = endpoints.len(), + required = outcome.required, + endpoint_count = outcome.endpoint_count, backoff_ms = next_ms, "DC writer floor is below required level during startup, scheduled reconnect" ); } } - if let Some(v) = inflight.get_mut(&key) { + if let Some(v) = inflight.get_mut(&outcome.key) { *v = v.saturating_sub(1); } } @@ -691,6 +772,68 @@ fn health_reconnect_budget(pool: &Arc, dc_groups: usize) -> usize { .clamp(HEALTH_RECONNECT_BUDGET_MIN, HEALTH_RECONNECT_BUDGET_MAX) } +fn update_family_runtime_state(pool: &Arc, family: IpFamily, degraded: bool) { + let now_epoch_secs = MePool::now_epoch_secs(); + let previous_state = pool.family_runtime_state(family); + let mut state_since_epoch_secs = pool.family_runtime_state_since_epoch_secs(family); + let previous_suppressed_until_epoch_secs = pool.family_suppressed_until_epoch_secs(family); + let previous_fail_streak = pool.family_fail_streak(family); + let previous_recover_success_streak = pool.family_recover_success_streak(family); + + let (next_state, suppressed_until_epoch_secs, fail_streak, recover_success_streak) = + if previous_suppressed_until_epoch_secs > now_epoch_secs { + let fail_streak = if degraded { + previous_fail_streak.saturating_add(1) + } else { + previous_fail_streak + }; + ( + MeFamilyRuntimeState::Suppressed, + previous_suppressed_until_epoch_secs, + fail_streak, + 0, + ) + } else if degraded { + let fail_streak = previous_fail_streak.saturating_add(1); + if fail_streak >= FAMILY_SUPPRESS_FAIL_STREAK_THRESHOLD { + ( + MeFamilyRuntimeState::Suppressed, + now_epoch_secs.saturating_add(FAMILY_SUPPRESS_DURATION_SECS), + fail_streak, + 0, + ) + } else { + (MeFamilyRuntimeState::Degraded, 0, fail_streak, 0) + } + } else if matches!(previous_state, MeFamilyRuntimeState::Healthy) { + (MeFamilyRuntimeState::Healthy, 0, 0, 0) + } else { + let recover_success_streak = previous_recover_success_streak.saturating_add(1); + if recover_success_streak >= FAMILY_RECOVER_SUCCESS_STREAK_TARGET { + (MeFamilyRuntimeState::Healthy, 0, 0, 0) + } else { + ( + MeFamilyRuntimeState::Recovering, + 0, + 0, + recover_success_streak, + ) + } + }; + + if next_state != previous_state || state_since_epoch_secs == 0 { + state_since_epoch_secs = now_epoch_secs; + } + pool.set_family_runtime_state( + family, + next_state, + state_since_epoch_secs, + suppressed_until_epoch_secs, + fail_streak, + recover_success_streak, + ); +} + fn should_emit_rate_limited_warn( next_allowed: &mut HashMap<(i32, IpFamily), Instant>, key: (i32, IpFamily), @@ -715,6 +858,7 @@ fn adaptive_floor_class_min( ) -> usize { if endpoint_count <= 1 { let min_single = (pool + .floor_runtime .me_adaptive_floor_min_writers_single_endpoint .load(std::sync::atomic::Ordering::Relaxed) as usize) .max(1); @@ -971,7 +1115,7 @@ async fn maybe_swap_idle_writer_for_cap( }; let connected = match tokio::time::timeout( - pool.me_one_timeout, + pool.reconnect_runtime.me_one_timeout, pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), ) .await @@ -1077,7 +1221,7 @@ async fn maybe_refresh_idle_writer_for_dc( }; let rotate_ok = match 
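
`update_family_runtime_state` (defined a bit further down in this hunk) is a four-state machine: 5 consecutive degraded passes suppress a family for 60 s, and a suppressed or degraded family must log 2 consecutive healthy passes (Recovering) before it is Healthy again. Distilled to a pure function, omitting the suppressed-until timer for brevity:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum State {
    Healthy,
    Degraded,
    Suppressed, // held for FAMILY_SUPPRESS_DURATION_SECS in the real code
    Recovering,
}

const FAIL_STREAK_THRESHOLD: u32 = 5; // FAMILY_SUPPRESS_FAIL_STREAK_THRESHOLD
const RECOVER_TARGET: u32 = 2; // FAMILY_RECOVER_SUCCESS_STREAK_TARGET

fn step(state: State, fails: u32, recovers: u32, degraded: bool) -> (State, u32, u32) {
    if degraded {
        let fails = fails + 1;
        if fails >= FAIL_STREAK_THRESHOLD {
            (State::Suppressed, fails, 0)
        } else {
            (State::Degraded, fails, 0)
        }
    } else if state == State::Healthy {
        (State::Healthy, 0, 0)
    } else {
        // Leaving a bad state requires a streak of healthy passes.
        let recovers = recovers + 1;
        if recovers >= RECOVER_TARGET {
            (State::Healthy, 0, 0)
        } else {
            (State::Recovering, 0, recovers)
        }
    }
}

fn main() {
    let (mut s, mut f, mut r) = (State::Healthy, 0, 0);
    for degraded in [true, true, true, true, true, false, false] {
        (s, f, r) = step(s, f, r, degraded);
        println!("degraded={degraded} -> {s:?} (fails={f}, recovers={r})");
    }
}
```
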
tokio::time::timeout( - pool.me_one_timeout, + pool.reconnect_runtime.me_one_timeout, pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), ) .await @@ -1188,7 +1332,7 @@ async fn recover_single_endpoint_outage( required: usize, outage_backoff: &mut HashMap<(i32, IpFamily), u64>, outage_next_attempt: &mut HashMap<(i32, IpFamily), Instant>, - reconnect_budget: &mut usize, + reconnect_sem: &Arc, ) { let now = Instant::now(); if let Some(ts) = outage_next_attempt.get(&key) @@ -1198,7 +1342,7 @@ async fn recover_single_endpoint_outage( } let (min_backoff_ms, max_backoff_ms) = pool.single_endpoint_outage_backoff_bounds_ms(); - if *reconnect_budget == 0 { + if reconnect_sem.available_permits() == 0 { outage_next_attempt.insert(key, now + Duration::from_millis(min_backoff_ms.max(250))); debug!( dc = %key.0, @@ -1209,7 +1353,17 @@ async fn recover_single_endpoint_outage( ); return; } - *reconnect_budget = (*reconnect_budget).saturating_sub(1); + let Ok(_reconnect_permit) = reconnect_sem.clone().try_acquire_owned() else { + outage_next_attempt.insert(key, now + Duration::from_millis(min_backoff_ms.max(250))); + debug!( + dc = %key.0, + family = ?key.1, + %endpoint, + required, + "Single-endpoint outage reconnect deferred by semaphore saturation" + ); + return; + }; pool.stats .increment_me_single_endpoint_outage_reconnect_attempt_total(); @@ -1218,7 +1372,7 @@ async fn recover_single_endpoint_outage( pool.stats .increment_me_single_endpoint_quarantine_bypass_total(); match tokio::time::timeout( - pool.me_one_timeout, + pool.reconnect_runtime.me_one_timeout, pool.connect_one_for_dc(endpoint, key.0, rng.as_ref()), ) .await @@ -1247,7 +1401,7 @@ async fn recover_single_endpoint_outage( } else { let one_endpoint = [endpoint]; match tokio::time::timeout( - pool.me_one_timeout, + pool.reconnect_runtime.me_one_timeout, pool.connect_endpoints_round_robin(key.0, &one_endpoint, rng.as_ref()), ) .await @@ -1372,7 +1526,7 @@ async fn maybe_rotate_single_endpoint_shadow( }; let rotate_ok = match tokio::time::timeout( - pool.me_one_timeout, + pool.reconnect_runtime.me_one_timeout, pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), ) .await @@ -1687,6 +1841,8 @@ mod tests { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ) diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs new file mode 100644 index 0000000..5be601e --- /dev/null +++ b/src/transport/middle_proxy/http_fetch.rs @@ -0,0 +1,183 @@ +use std::sync::Arc; +use std::time::Duration; + +use http_body_util::{BodyExt, Empty}; +use hyper::header::{CONNECTION, DATE, HOST, USER_AGENT}; +use hyper::{Method, Request}; +use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::TlsConnector; +use tracing::debug; + +use crate::error::{ProxyError, Result}; +use crate::network::dns_overrides::resolve_socket_addr; +use crate::transport::{UpstreamManager, UpstreamStream}; + +const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); + +pub(crate) struct HttpsGetResponse { + pub(crate) status: u16, + pub(crate) date_header: Option, + pub(crate) body: Vec, +} + +fn build_tls_client_config() -> Arc { + let mut root_store = rustls::RootCertStore::empty(); + 
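
`should_emit_rate_limited_warn` gates log emission per `(dc, family)` key through a map of next-allowed instants; its body is not shown in this hunk, so the following is one plausible shape of that helper rather than the crate's actual code:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

fn should_emit<K: std::hash::Hash + Eq + Copy>(
    next_allowed: &mut HashMap<K, Instant>,
    key: K,
    now: Instant,
    cooldown: Duration,
) -> bool {
    match next_allowed.get(&key) {
        // Still inside the cooldown window: stay quiet.
        Some(at) if *at > now => false,
        // Emit and arm the cooldown for this key only.
        _ => {
            next_allowed.insert(key, now + cooldown);
            true
        }
    }
}

fn main() {
    let mut gate: HashMap<(i32, u8), Instant> = HashMap::new();
    let now = Instant::now();
    let cooldown = Duration::from_secs(30);
    assert!(should_emit(&mut gate, (2, 4), now, cooldown));
    assert!(!should_emit(&mut gate, (2, 4), now, cooldown)); // cooling down
    assert!(should_emit(&mut gate, (2, 6), now, cooldown)); // independent key
    println!("rate-limit gate behaves per key");
}
```

Keying by `(dc, family)` means a single flapping DC cannot starve warnings about other DCs out of the log.
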
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let provider = rustls::crypto::ring::default_provider(); + let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("HTTPS fetch rustls protocol versions must be valid") + .with_root_certificates(root_store) + .with_no_client_auth(); + Arc::new(config) +} + +fn extract_host_port_path(url: &str) -> Result<(String, u16, String)> { + let parsed = + url::Url::parse(url).map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; + if parsed.scheme() != "https" { + return Err(ProxyError::Proxy(format!( + "unsupported URL scheme '{}': only https is supported", + parsed.scheme() + ))); + } + + let host = parsed + .host_str() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no host: {url}")))? + .to_string(); + let port = parsed + .port_or_known_default() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no known port: {url}")))?; + + let mut path = parsed.path().to_string(); + if path.is_empty() { + path.push('/'); + } + if let Some(query) = parsed.query() { + path.push('?'); + path.push_str(query); + } + + Ok((host, port, path)) +} + +async fn resolve_target_addr(host: &str, port: u16) -> Result { + if let Some(addr) = resolve_socket_addr(host, port) { + return Ok(addr); + } + + let addrs: Vec = tokio::net::lookup_host((host, port)) + .await + .map_err(|e| ProxyError::Proxy(format!("DNS resolve failed for {host}:{port}: {e}")))? + .collect(); + + if let Some(addr) = addrs.iter().copied().find(|addr| addr.is_ipv4()) { + return Ok(addr); + } + + addrs + .first() + .copied() + .ok_or_else(|| ProxyError::Proxy(format!("DNS returned no addresses for {host}:{port}"))) +} + +async fn connect_https_transport( + host: &str, + port: u16, + upstream: Option>, +) -> Result { + if let Some(manager) = upstream { + let target = resolve_target_addr(host, port).await?; + return timeout(HTTP_CONNECT_TIMEOUT, manager.connect(target, None, None)) + .await + .map_err(|_| ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")))? + .map_err(|e| { + ProxyError::Proxy(format!("upstream connect failed for {host}:{port}: {e}")) + }); + } + + if let Some(addr) = resolve_socket_addr(host, port) { + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + return Ok(UpstreamStream::Tcp(stream)); + } + + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect((host, port))) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + Ok(UpstreamStream::Tcp(stream)) +} + +pub(crate) async fn https_get( + url: &str, + upstream: Option>, +) -> Result { + let (host, port, path_and_query) = extract_host_port_path(url)?; + let stream = connect_https_transport(&host, port, upstream).await?; + + let server_name = ServerName::try_from(host.clone()) + .map_err(|_| ProxyError::Proxy(format!("invalid TLS server name: {host}")))?; + let connector = TlsConnector::from(build_tls_client_config()); + let tls_stream = timeout(HTTP_REQUEST_TIMEOUT, connector.connect(server_name, stream)) + .await + .map_err(|_| ProxyError::Proxy(format!("TLS handshake timeout for {host}:{port}")))? 
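
`resolve_target_addr` above prefers IPv4 results and only falls back to whatever family comes first; that bias matters because the upstream egress may not have working v6. A minimal standalone sketch of the same selection against `example.com`:

```rust
use std::net::SocketAddr;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (host, port) = ("example.com", 443u16);
    // Collect every address the resolver returns, both families.
    let addrs: Vec<SocketAddr> = tokio::net::lookup_host((host, port)).await?.collect();

    // Prefer the first IPv4 address; otherwise take any address at all.
    let picked = addrs
        .iter()
        .copied()
        .find(|a| a.is_ipv4())
        .or_else(|| addrs.first().copied())
        .ok_or("DNS returned no addresses")?;

    println!("resolved {host}:{port} -> {picked}");
    Ok(())
}
```
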
+ .map_err(|e| ProxyError::Proxy(format!("TLS handshake failed for {host}:{port}: {e}")))?; + + let (mut sender, connection) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .map_err(|e| ProxyError::Proxy(format!("HTTP handshake failed for {host}:{port}: {e}")))?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + debug!(error = %e, "HTTPS fetch connection task failed"); + } + }); + + let host_header = if port == 443 { + host.clone() + } else { + format!("{host}:{port}") + }; + + let request = Request::builder() + .method(Method::GET) + .uri(path_and_query) + .header(HOST, host_header) + .header(USER_AGENT, "telemt-middle-proxy/1") + .header(CONNECTION, "close") + .body(Empty::::new()) + .map_err(|e| ProxyError::Proxy(format!("build HTTP request failed for {url}: {e}")))?; + + let response = timeout(HTTP_REQUEST_TIMEOUT, sender.send_request(request)) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP request timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP request failed for {url}: {e}")))?; + + let status = response.status().as_u16(); + let date_header = response + .headers() + .get(DATE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()); + + let body = timeout(HTTP_REQUEST_TIMEOUT, response.into_body().collect()) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP body read timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP body read failed for {url}: {e}")))? + .to_bytes() + .to_vec(); + + Ok(HttpsGetResponse { + status, + date_header, + body, + }) +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 5536869..6dfbee6 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -13,6 +13,7 @@ mod health_integration_tests; #[cfg(test)] #[path = "tests/health_regression_tests.rs"] mod health_regression_tests; +mod http_fetch; mod ping; mod pool; mod pool_config; @@ -44,7 +45,8 @@ use bytes::Bytes; #[allow(unused_imports)] pub use config_updater::{ - ProxyConfigData, fetch_proxy_config, fetch_proxy_config_with_raw, load_proxy_config_cache, + ProxyConfigData, fetch_proxy_config, fetch_proxy_config_via_upstream, + fetch_proxy_config_with_raw, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, me_config_updater, save_proxy_config_cache, }; pub use health::{me_drain_timeout_enforcer, me_health_monitor, me_zombie_writer_watchdog}; @@ -57,7 +59,8 @@ pub use pool::MePool; pub use pool_nat::{detect_public_ip, stun_probe}; pub use registry::ConnRegistry; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; -pub use secret::fetch_proxy_secret; +#[allow(unused_imports)] +pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 71ab257..249d387 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -2,13 +2,15 @@ use std::collections::{HashMap, HashSet}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::sync::atomic::{ AtomicBool, AtomicI32, AtomicU8, AtomicU32, AtomicU64, AtomicUsize, Ordering, }; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use tokio::sync::{Mutex, Notify, RwLock, mpsc}; +use arc_swap::ArcSwap; +use tokio::sync::{Mutex, RwLock, 
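
One small but easy-to-miss detail in `https_get` above: the Host header omits the default HTTPS port and appends any non-default one explicitly. The rule isolated as a sketch:

```rust
fn host_header(host: &str, port: u16) -> String {
    if port == 443 {
        host.to_string() // default port is implied, per convention
    } else {
        format!("{host}:{port}")
    }
}

fn main() {
    assert_eq!(host_header("core.telegram.org", 443), "core.telegram.org");
    assert_eq!(host_header("core.telegram.org", 8443), "core.telegram.org:8443");
    println!("Host header formatting ok");
}
```
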
mpsc, watch}; use tokio_util::sync::CancellationToken; use crate::config::{ @@ -55,6 +57,87 @@ pub struct MeWriter { pub allow_drain_fallback: Arc, } +pub(super) struct WritersState { + // HARD INVARIANT: + // All writers.store() calls MUST be guarded by writers_write_guard. + writers: ArcSwap>, + writers_write_guard: Mutex<()>, +} + +impl WritersState { + pub(super) fn new() -> Self { + Self { + writers: ArcSwap::from_pointee(Vec::new()), + writers_write_guard: Mutex::new(()), + } + } + + pub(super) fn snapshot(&self) -> Arc> { + self.writers.load_full() + } + + pub(super) async fn read(&self) -> Arc> { + self.snapshot() + } + + pub(super) async fn write(&self) -> WritersWriteGuard<'_> { + let guard = self.writers_write_guard.lock().await; + let writers = (*self.writers.load_full()).clone(); + WritersWriteGuard { + state: self, + _guard: guard, + writers, + } + } + + pub(super) async fn update(&self, f: F) -> R + where + F: FnOnce(&mut Vec) -> R, + { + let mut guard = self.write().await; + f(&mut guard) + } + + fn debug_assert_store_guarded(&self) { + debug_assert!( + self.writers_write_guard.try_lock().is_err(), + "HARD INVARIANT violated: writers.store() without writers_write_guard" + ); + } + + fn store_guarded(&self, writers: Vec) { + self.debug_assert_store_guarded(); + self.writers.store(Arc::new(writers)); + } +} + +pub(super) struct WritersWriteGuard<'a> { + state: &'a WritersState, + _guard: tokio::sync::MutexGuard<'a, ()>, + writers: Vec, +} + +impl Deref for WritersWriteGuard<'_> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.writers + } +} + +impl DerefMut for WritersWriteGuard<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.writers + } +} + +impl Drop for WritersWriteGuard<'_> { + fn drop(&mut self) { + let writers = std::mem::take(&mut self.writers); + self.state.store_guarded(writers); + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub(super) enum WriterContour { @@ -69,6 +152,10 @@ impl WriterContour { } pub(super) fn from_u8(value: u8) -> Self { + debug_assert!( + value <= Self::Draining as u8, + "Unexpected WriterContour discriminant: {value}" + ); match value { 0 => Self::Warm, 1 => Self::Active, @@ -87,16 +174,34 @@ pub(crate) enum MeFamilyRuntimeState { Recovering = 3, } -impl MeFamilyRuntimeState { - pub(crate) fn from_u8(value: u8) -> Self { - match value { - 1 => Self::Degraded, - 2 => Self::Suppressed, - 3 => Self::Recovering, - _ => Self::Healthy, +#[derive(Debug, Clone)] +pub(crate) struct FamilyHealthSnapshot { + pub(crate) state: MeFamilyRuntimeState, + pub(crate) state_since_epoch_secs: u64, + pub(crate) suppressed_until_epoch_secs: u64, + pub(crate) fail_streak: u32, + pub(crate) recover_success_streak: u32, +} + +impl FamilyHealthSnapshot { + fn new( + state: MeFamilyRuntimeState, + state_since_epoch_secs: u64, + suppressed_until_epoch_secs: u64, + fail_streak: u32, + recover_success_streak: u32, + ) -> Self { + Self { + state, + state_since_epoch_secs, + suppressed_until_epoch_secs, + fail_streak, + recover_success_streak, } } +} +impl MeFamilyRuntimeState { pub(crate) fn as_str(self) -> &'static str { match self { Self::Healthy => "healthy", @@ -143,16 +248,89 @@ pub struct SecretSnapshot { pub secret: Vec, } -#[allow(dead_code)] -pub struct MePool { +pub struct RoutingCore { pub(super) registry: Arc, - pub(super) writers: Arc>>, + pub(super) writers: Arc, pub(super) rr: AtomicU64, - pub(super) decision: NetworkDecision, - pub(super) upstream: Option>, - pub(super) rng: Arc, - pub(super) 
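
`WritersState` below is a copy-on-write cell: readers take a lock-free `ArcSwap` snapshot, while every store is funneled through one async `Mutex` so concurrent writers cannot lose each other's updates. A generic, self-contained sketch of the pattern (requires the `arc-swap` crate):

```rust
use arc_swap::ArcSwap;
use std::sync::Arc;
use tokio::sync::Mutex;

struct GuardedVec<T> {
    data: ArcSwap<Vec<T>>,
    write_guard: Mutex<()>, // serializes clone-modify-store cycles
}

impl<T: Clone> GuardedVec<T> {
    fn new() -> Self {
        Self {
            data: ArcSwap::from_pointee(Vec::new()),
            write_guard: Mutex::new(()),
        }
    }

    /// Lock-free snapshot for readers on the hot path.
    fn snapshot(&self) -> Arc<Vec<T>> {
        self.data.load_full()
    }

    /// Clone-modify-store under the write mutex; without the guard,
    /// two concurrent updates could each clone the same base vector
    /// and the second store would silently drop the first one's edit.
    async fn update<R>(&self, f: impl FnOnce(&mut Vec<T>) -> R) -> R {
        let _guard = self.write_guard.lock().await;
        let mut copy = (*self.data.load_full()).clone();
        let out = f(&mut copy);
        self.data.store(Arc::new(copy));
        out
    }
}

#[tokio::main]
async fn main() {
    let state = GuardedVec::<u32>::new();
    state.update(|v| v.push(7)).await;
    println!("snapshot = {:?}", state.snapshot());
}
```

The `debug_assert_store_guarded` check in the real code enforces the same invariant at runtime in debug builds.
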
proxy_tag: Option>, - pub(super) proxy_secret: Arc>, + pub(super) writer_epoch: watch::Sender, + pub(super) preferred_endpoints_by_dc: ArcSwap>>, +} + +pub(super) struct ReinitCore { + pub(super) generation: AtomicU64, + pub(super) active_generation: AtomicU64, + pub(super) warm_generation: AtomicU64, + pub(super) pending_hardswap_generation: AtomicU64, + pub(super) pending_hardswap_started_at_epoch_secs: AtomicU64, + pub(super) pending_hardswap_map_hash: AtomicU64, + pub(super) hardswap: AtomicBool, + pub(super) me_hardswap_warmup_delay_min_ms: AtomicU64, + pub(super) me_hardswap_warmup_delay_max_ms: AtomicU64, + pub(super) me_hardswap_warmup_extra_passes: AtomicU32, + pub(super) me_hardswap_warmup_pass_backoff_base_ms: AtomicU64, +} + +pub(super) struct WriterLifecycleCore { + pub(super) me_keepalive_enabled: bool, + pub(super) me_keepalive_interval: Duration, + pub(super) me_keepalive_jitter: Duration, + pub(super) me_keepalive_payload_random: bool, + pub(super) rpc_proxy_req_every_secs: AtomicU64, + pub(super) writer_cmd_channel_capacity: usize, +} + +pub(super) struct RouteRuntimeCore { + pub(super) me_route_no_writer_mode: AtomicU8, + pub(super) me_route_no_writer_wait: Duration, + pub(super) me_route_hybrid_max_wait: Duration, + pub(super) me_route_blocking_send_timeout: Option, + pub(super) me_route_last_success_epoch_ms: AtomicU64, + pub(super) me_route_hybrid_timeout_warn_epoch_ms: AtomicU64, + pub(super) me_async_recovery_last_trigger_epoch_ms: AtomicU64, + pub(super) me_route_inline_recovery_attempts: u32, + pub(super) me_route_inline_recovery_wait: Duration, +} + +pub(super) struct HealthRuntimeCore { + pub(super) me_health_interval_ms_unhealthy: AtomicU64, + pub(super) me_health_interval_ms_healthy: AtomicU64, + pub(super) me_warn_rate_limit_ms: AtomicU64, + pub(super) family_health_v4: ArcSwap, + pub(super) family_health_v6: ArcSwap, +} + +pub(super) struct DrainRuntimeCore { + pub(super) me_pool_drain_ttl_secs: AtomicU64, + pub(super) me_instadrain: AtomicBool, + pub(super) me_pool_drain_threshold: AtomicU64, + pub(super) me_pool_drain_soft_evict_enabled: AtomicBool, + pub(super) me_pool_drain_soft_evict_grace_secs: AtomicU64, + pub(super) me_pool_drain_soft_evict_per_writer: AtomicU8, + pub(super) me_pool_drain_soft_evict_budget_per_core: AtomicU32, + pub(super) me_pool_drain_soft_evict_cooldown_ms: AtomicU64, + pub(super) me_pool_force_close_secs: AtomicU64, + pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, + pub(super) me_last_drain_gate_route_quorum_ok: AtomicBool, + pub(super) me_last_drain_gate_redundancy_ok: AtomicBool, + pub(super) me_last_drain_gate_block_reason: AtomicU8, + pub(super) me_last_drain_gate_updated_at_epoch_secs: AtomicU64, +} + +pub(super) struct SingleEndpointRuntimeCore { + pub(super) me_single_endpoint_shadow_writers: AtomicU8, + pub(super) me_single_endpoint_outage_mode_enabled: AtomicBool, + pub(super) me_single_endpoint_outage_disable_quarantine: AtomicBool, + pub(super) me_single_endpoint_outage_backoff_min_ms: AtomicU64, + pub(super) me_single_endpoint_outage_backoff_max_ms: AtomicU64, + pub(super) me_single_endpoint_shadow_rotate_every_secs: AtomicU64, +} + +pub(super) struct BindingPolicyCore { + pub(super) me_bind_stale_mode: AtomicU8, + pub(super) me_bind_stale_ttl_secs: AtomicU64, +} + +pub(super) struct NatRuntimeCore { pub(super) nat_ip_cfg: Option, pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, @@ -164,14 +342,15 @@ pub struct MePool { pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) 
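
`FamilyHealthSnapshot` replaces ten per-family atomics with one immutable struct published through `ArcSwap`, so a reader can never observe a torn mix of old and new fields. A minimal sketch of the publish/read cycle; the `Health` fields here are illustrative:

```rust
use arc_swap::ArcSwap;
use std::sync::Arc;

#[derive(Debug, Clone)]
struct Health {
    state: &'static str,
    fail_streak: u32,
    suppressed_until_epoch_secs: u64,
}

fn main() {
    let cell = ArcSwap::from_pointee(Health {
        state: "healthy",
        fail_streak: 0,
        suppressed_until_epoch_secs: 0,
    });

    // Writer: build the complete snapshot, then publish in one store.
    cell.store(Arc::new(Health {
        state: "suppressed",
        fail_streak: 5,
        suppressed_until_epoch_secs: 1_700_000_060,
    }));

    // Reader: a single load yields an internally consistent view;
    // fail_streak can never be seen alongside a stale state value.
    let snap = cell.load();
    println!("{} (fails={})", snap.state, snap.fail_streak);
}
```
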
nat_probe_disabled: std::sync::atomic::AtomicBool, pub(super) stun_backoff_until: Arc>>, + pub(super) nat_reflection_cache: Arc>, + pub(super) nat_reflection_singleflight_v4: Arc>, + pub(super) nat_reflection_singleflight_v6: Arc>, +} + +pub(super) struct ReconnectRuntimeCore { + #[allow(dead_code)] pub(super) me_one_retry: u8, pub(super) me_one_timeout: Duration, - pub(super) me_keepalive_enabled: bool, - pub(super) me_keepalive_interval: Duration, - pub(super) me_keepalive_jitter: Duration, - pub(super) me_keepalive_payload_random: bool, - pub(super) rpc_proxy_req_every_secs: AtomicU64, - pub(super) writer_cmd_channel_capacity: usize, pub(super) me_warmup_stagger_enabled: bool, pub(super) me_warmup_step_delay: Duration, pub(super) me_warmup_step_jitter: Duration, @@ -179,12 +358,9 @@ pub struct MePool { pub(super) me_reconnect_backoff_base: Duration, pub(super) me_reconnect_backoff_cap: Duration, pub(super) me_reconnect_fast_retry_count: u32, - pub(super) me_single_endpoint_shadow_writers: AtomicU8, - pub(super) me_single_endpoint_outage_mode_enabled: AtomicBool, - pub(super) me_single_endpoint_outage_disable_quarantine: AtomicBool, - pub(super) me_single_endpoint_outage_backoff_min_ms: AtomicU64, - pub(super) me_single_endpoint_outage_backoff_max_ms: AtomicU64, - pub(super) me_single_endpoint_shadow_rotate_every_secs: AtomicU64, +} + +pub(super) struct FloorRuntimeCore { pub(super) me_floor_mode: AtomicU8, pub(super) me_adaptive_floor_idle_secs: AtomicU64, pub(super) me_adaptive_floor_min_writers_single_endpoint: AtomicU8, @@ -209,78 +385,63 @@ pub struct MePool { pub(super) me_adaptive_floor_warm_cap_effective: AtomicU64, pub(super) me_adaptive_floor_active_writers_current: AtomicU64, pub(super) me_adaptive_floor_warm_writers_current: AtomicU64, +} + +pub(super) struct WriterSelectionPolicyCore { + pub(super) secret_atomic_snapshot: AtomicBool, + pub(super) me_deterministic_writer_sort: AtomicBool, + pub(super) me_writer_pick_mode: AtomicU8, + pub(super) me_writer_pick_sample_size: AtomicU8, +} + +pub(super) struct TransportPolicyCore { + pub(super) me_socks_kdf_policy: AtomicU8, + pub(super) me_reader_route_data_wait_ms: Arc, +} + +#[allow(dead_code)] +pub struct MePool { + pub(super) routing: Arc, + pub(super) reinit: Arc, + pub(super) writer_lifecycle: Arc, + pub(super) route_runtime: Arc, + pub(super) health_runtime: Arc, + pub(super) drain_runtime: Arc, + pub(super) single_endpoint_runtime: Arc, + pub(super) binding_policy: Arc, + pub(super) nat_runtime: Arc, + pub(super) reconnect_runtime: Arc, + pub(super) floor_runtime: Arc, + pub(super) writer_selection_policy: Arc, + pub(super) transport_policy: Arc, + pub(super) decision: NetworkDecision, + pub(super) upstream: Option>, + pub(super) rng: Arc, + pub(super) proxy_tag: Option>, + pub(super) proxy_secret: Arc>, pub(super) proxy_map_v4: Arc>>>, pub(super) proxy_map_v6: Arc>>>, pub(super) endpoint_dc_map: Arc>>>, pub(super) default_dc: AtomicI32, pub(super) next_writer_id: AtomicU64, - pub(super) ping_tracker: Arc>>, - pub(super) ping_tracker_last_cleanup_epoch_ms: AtomicU64, pub(super) rtt_stats: Arc>>, - pub(super) nat_reflection_cache: Arc>, - pub(super) nat_reflection_singleflight_v4: Arc>, - pub(super) nat_reflection_singleflight_v6: Arc>, - pub(super) writer_available: Arc, pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, pub(super) draining_active_runtime: AtomicU64, pub(super) stats: Arc, - pub(super) generation: AtomicU64, - pub(super) active_generation: 
AtomicU64, - pub(super) warm_generation: AtomicU64, - pub(super) pending_hardswap_generation: AtomicU64, - pub(super) pending_hardswap_started_at_epoch_secs: AtomicU64, - pub(super) pending_hardswap_map_hash: AtomicU64, - pub(super) hardswap: AtomicBool, pub(super) endpoint_quarantine: Arc>>, pub(super) kdf_material_fingerprint: Arc>>, - pub(super) me_pool_drain_ttl_secs: AtomicU64, - pub(super) me_instadrain: AtomicBool, - pub(super) me_pool_drain_threshold: AtomicU64, - pub(super) me_pool_drain_soft_evict_enabled: AtomicBool, - pub(super) me_pool_drain_soft_evict_grace_secs: AtomicU64, - pub(super) me_pool_drain_soft_evict_per_writer: AtomicU8, - pub(super) me_pool_drain_soft_evict_budget_per_core: AtomicU32, - pub(super) me_pool_drain_soft_evict_cooldown_ms: AtomicU64, - pub(super) me_pool_force_close_secs: AtomicU64, - pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, - pub(super) me_hardswap_warmup_delay_min_ms: AtomicU64, - pub(super) me_hardswap_warmup_delay_max_ms: AtomicU64, - pub(super) me_hardswap_warmup_extra_passes: AtomicU32, - pub(super) me_hardswap_warmup_pass_backoff_base_ms: AtomicU64, - pub(super) me_bind_stale_mode: AtomicU8, - pub(super) me_bind_stale_ttl_secs: AtomicU64, - pub(super) secret_atomic_snapshot: AtomicBool, - pub(super) me_deterministic_writer_sort: AtomicBool, - pub(super) me_writer_pick_mode: AtomicU8, - pub(super) me_writer_pick_sample_size: AtomicU8, - pub(super) me_socks_kdf_policy: AtomicU8, - pub(super) me_reader_route_data_wait_ms: Arc, - pub(super) me_route_no_writer_mode: AtomicU8, - pub(super) me_route_no_writer_wait: Duration, - pub(super) me_route_inline_recovery_attempts: u32, - pub(super) me_route_inline_recovery_wait: Duration, - pub(super) me_health_interval_ms_unhealthy: AtomicU64, - pub(super) me_health_interval_ms_healthy: AtomicU64, - pub(super) me_warn_rate_limit_ms: AtomicU64, - pub(super) me_family_v4_runtime_state: AtomicU8, - pub(super) me_family_v6_runtime_state: AtomicU8, - pub(super) me_family_v4_state_since_epoch_secs: AtomicU64, - pub(super) me_family_v6_state_since_epoch_secs: AtomicU64, - pub(super) me_family_v4_suppressed_until_epoch_secs: AtomicU64, - pub(super) me_family_v6_suppressed_until_epoch_secs: AtomicU64, - pub(super) me_family_v4_fail_streak: AtomicU32, - pub(super) me_family_v6_fail_streak: AtomicU32, - pub(super) me_family_v4_recover_success_streak: AtomicU32, - pub(super) me_family_v6_recover_success_streak: AtomicU32, - pub(super) me_last_drain_gate_route_quorum_ok: AtomicBool, - pub(super) me_last_drain_gate_redundancy_ok: AtomicBool, - pub(super) me_last_drain_gate_block_reason: AtomicU8, - pub(super) me_last_drain_gate_updated_at_epoch_secs: AtomicU64, pub(super) runtime_ready: AtomicBool, pool_size: usize, - pub(super) preferred_endpoints_by_dc: Arc>>>, +} + +impl Deref for MePool { + type Target = RoutingCore; + + fn deref(&self) -> &Self::Target { + self.routing.as_ref() + } } #[derive(Debug, Default)] @@ -396,6 +557,8 @@ impl MePool { me_warn_rate_limit_ms: u64, me_route_no_writer_mode: MeRouteNoWriterMode, me_route_no_writer_wait_ms: u64, + me_route_hybrid_max_wait_ms: u64, + me_route_blocking_send_timeout_ms: u64, me_route_inline_recovery_attempts: u32, me_route_inline_recovery_wait_ms: u64, ) -> Arc { @@ -410,10 +573,220 @@ impl MePool { me_route_backpressure_high_timeout_ms, me_route_backpressure_high_watermark_pct, ); + let (writer_epoch, _) = watch::channel(0u64); + let now_epoch_secs = Self::now_epoch_secs(); Arc::new(Self { - registry, - writers: Arc::new(RwLock::new(Vec::new())), - rr: 
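
After the big struct split, `MePool` gains a `Deref` to `RoutingCore` (shown just below) so hot-path call sites like `pool.rr` or `pool.writers` keep compiling without a mechanical rename. A compact sketch of that shim:

```rust
use std::ops::Deref;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

struct RoutingCore {
    rr: AtomicU64, // round-robin cursor, as in the diff
}

struct Pool {
    routing: Arc<RoutingCore>,
    // ... other cores (reinit, drain_runtime, ...) would live here ...
}

impl Deref for Pool {
    type Target = RoutingCore;

    fn deref(&self) -> &Self::Target {
        self.routing.as_ref()
    }
}

fn main() {
    let pool = Pool {
        routing: Arc::new(RoutingCore { rr: AtomicU64::new(0) }),
    };
    // Field access auto-derefs: pool.rr resolves to pool.routing.rr.
    let next = pool.rr.fetch_add(1, Ordering::Relaxed);
    println!("round-robin counter was {next}");
}
```

Deref-to-a-subfield is a pragmatic migration aid; only the most frequently accessed core gets the privilege, everything else is reached explicitly (`pool.drain_runtime`, `pool.reinit`, ...).
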
AtomicU64::new(0), + routing: Arc::new(RoutingCore { + registry, + writers: Arc::new(WritersState::new()), + rr: AtomicU64::new(0), + writer_epoch, + preferred_endpoints_by_dc: ArcSwap::from_pointee(preferred_endpoints_by_dc), + }), + reinit: Arc::new(ReinitCore { + generation: AtomicU64::new(1), + active_generation: AtomicU64::new(1), + warm_generation: AtomicU64::new(0), + pending_hardswap_generation: AtomicU64::new(0), + pending_hardswap_started_at_epoch_secs: AtomicU64::new(0), + pending_hardswap_map_hash: AtomicU64::new(0), + hardswap: AtomicBool::new(hardswap), + me_hardswap_warmup_delay_min_ms: AtomicU64::new(me_hardswap_warmup_delay_min_ms), + me_hardswap_warmup_delay_max_ms: AtomicU64::new(me_hardswap_warmup_delay_max_ms), + me_hardswap_warmup_extra_passes: AtomicU32::new( + me_hardswap_warmup_extra_passes as u32, + ), + me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new( + me_hardswap_warmup_pass_backoff_base_ms, + ), + }), + writer_lifecycle: Arc::new(WriterLifecycleCore { + me_keepalive_enabled, + me_keepalive_interval: Duration::from_secs(me_keepalive_interval_secs), + me_keepalive_jitter: Duration::from_secs(me_keepalive_jitter_secs), + me_keepalive_payload_random, + rpc_proxy_req_every_secs: AtomicU64::new(rpc_proxy_req_every_secs), + writer_cmd_channel_capacity: me_writer_cmd_channel_capacity.max(1), + }), + route_runtime: Arc::new(RouteRuntimeCore { + me_route_no_writer_mode: AtomicU8::new(me_route_no_writer_mode.as_u8()), + me_route_no_writer_wait: Duration::from_millis(me_route_no_writer_wait_ms), + me_route_hybrid_max_wait: Duration::from_millis( + me_route_hybrid_max_wait_ms.max(50), + ), + me_route_blocking_send_timeout: if me_route_blocking_send_timeout_ms == 0 { + None + } else { + Some(Duration::from_millis( + me_route_blocking_send_timeout_ms.min(5_000), + )) + }, + me_route_last_success_epoch_ms: AtomicU64::new(0), + me_route_hybrid_timeout_warn_epoch_ms: AtomicU64::new(0), + me_async_recovery_last_trigger_epoch_ms: AtomicU64::new(0), + me_route_inline_recovery_attempts, + me_route_inline_recovery_wait: Duration::from_millis( + me_route_inline_recovery_wait_ms, + ), + }), + health_runtime: Arc::new(HealthRuntimeCore { + me_health_interval_ms_unhealthy: AtomicU64::new( + me_health_interval_ms_unhealthy.max(1), + ), + me_health_interval_ms_healthy: AtomicU64::new(me_health_interval_ms_healthy.max(1)), + me_warn_rate_limit_ms: AtomicU64::new(me_warn_rate_limit_ms.max(1)), + family_health_v4: ArcSwap::from_pointee(FamilyHealthSnapshot::new( + MeFamilyRuntimeState::Healthy, + now_epoch_secs, + 0, + 0, + 0, + )), + family_health_v6: ArcSwap::from_pointee(FamilyHealthSnapshot::new( + MeFamilyRuntimeState::Healthy, + now_epoch_secs, + 0, + 0, + 0, + )), + }), + drain_runtime: Arc::new(DrainRuntimeCore { + me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), + me_instadrain: AtomicBool::new(me_instadrain), + me_pool_drain_threshold: AtomicU64::new(me_pool_drain_threshold), + me_pool_drain_soft_evict_enabled: AtomicBool::new(me_pool_drain_soft_evict_enabled), + me_pool_drain_soft_evict_grace_secs: AtomicU64::new( + me_pool_drain_soft_evict_grace_secs, + ), + me_pool_drain_soft_evict_per_writer: AtomicU8::new( + me_pool_drain_soft_evict_per_writer.max(1), + ), + me_pool_drain_soft_evict_budget_per_core: AtomicU32::new( + me_pool_drain_soft_evict_budget_per_core.max(1) as u32, + ), + me_pool_drain_soft_evict_cooldown_ms: AtomicU64::new( + me_pool_drain_soft_evict_cooldown_ms.max(1), + ), + me_pool_force_close_secs: 
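
The constructor normalizes raw knobs exactly once, as in the `RouteRuntimeCore` init above: `0` disables the blocking-send timeout entirely, non-zero values are capped at 5 s, and the hybrid max wait gets a 50 ms floor, so the hot path never re-validates. Isolated as pure helpers:

```rust
use std::time::Duration;

/// 0 disables the timeout; anything else is capped at 5 s, as in the diff.
fn blocking_send_timeout(ms: u64) -> Option<Duration> {
    if ms == 0 {
        None
    } else {
        Some(Duration::from_millis(ms.min(5_000)))
    }
}

/// Floor of 50 ms, matching me_route_hybrid_max_wait_ms.max(50) above.
fn hybrid_max_wait(ms: u64) -> Duration {
    Duration::from_millis(ms.max(50))
}

fn main() {
    assert_eq!(blocking_send_timeout(0), None);
    assert_eq!(blocking_send_timeout(9_000), Some(Duration::from_millis(5_000)));
    assert_eq!(hybrid_max_wait(10), Duration::from_millis(50));
    println!("normalization rules hold");
}
```
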
AtomicU64::new(Self::normalize_force_close_secs( + me_pool_force_close_secs, + )), + me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( + me_pool_min_fresh_ratio, + )), + me_last_drain_gate_route_quorum_ok: AtomicBool::new(false), + me_last_drain_gate_redundancy_ok: AtomicBool::new(false), + me_last_drain_gate_block_reason: AtomicU8::new(MeDrainGateReason::Open as u8), + me_last_drain_gate_updated_at_epoch_secs: AtomicU64::new(now_epoch_secs), + }), + single_endpoint_runtime: Arc::new(SingleEndpointRuntimeCore { + me_single_endpoint_shadow_writers: AtomicU8::new(me_single_endpoint_shadow_writers), + me_single_endpoint_outage_mode_enabled: AtomicBool::new( + me_single_endpoint_outage_mode_enabled, + ), + me_single_endpoint_outage_disable_quarantine: AtomicBool::new( + me_single_endpoint_outage_disable_quarantine, + ), + me_single_endpoint_outage_backoff_min_ms: AtomicU64::new( + me_single_endpoint_outage_backoff_min_ms, + ), + me_single_endpoint_outage_backoff_max_ms: AtomicU64::new( + me_single_endpoint_outage_backoff_max_ms, + ), + me_single_endpoint_shadow_rotate_every_secs: AtomicU64::new( + me_single_endpoint_shadow_rotate_every_secs, + ), + }), + binding_policy: Arc::new(BindingPolicyCore { + me_bind_stale_mode: AtomicU8::new(me_bind_stale_mode.as_u8()), + me_bind_stale_ttl_secs: AtomicU64::new(me_bind_stale_ttl_secs), + }), + nat_runtime: Arc::new(NatRuntimeCore { + nat_ip_cfg: nat_ip, + nat_ip_detected: Arc::new(RwLock::new(None)), + nat_probe, + nat_stun, + nat_stun_servers, + nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), + nat_probe_concurrency: nat_probe_concurrency.max(1), + detected_ipv6, + nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), + nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), + stun_backoff_until: Arc::new(RwLock::new(None)), + nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), + nat_reflection_singleflight_v4: Arc::new(Mutex::new(())), + nat_reflection_singleflight_v6: Arc::new(Mutex::new(())), + }), + reconnect_runtime: Arc::new(ReconnectRuntimeCore { + me_one_retry, + me_one_timeout: Duration::from_millis(me_one_timeout_ms), + me_warmup_stagger_enabled, + me_warmup_step_delay: Duration::from_millis(me_warmup_step_delay_ms), + me_warmup_step_jitter: Duration::from_millis(me_warmup_step_jitter_ms), + me_reconnect_max_concurrent_per_dc, + me_reconnect_backoff_base: Duration::from_millis(me_reconnect_backoff_base_ms), + me_reconnect_backoff_cap: Duration::from_millis(me_reconnect_backoff_cap_ms), + me_reconnect_fast_retry_count, + }), + floor_runtime: Arc::new(FloorRuntimeCore { + me_floor_mode: AtomicU8::new(me_floor_mode.as_u8()), + me_adaptive_floor_idle_secs: AtomicU64::new(me_adaptive_floor_idle_secs), + me_adaptive_floor_min_writers_single_endpoint: AtomicU8::new( + me_adaptive_floor_min_writers_single_endpoint, + ), + me_adaptive_floor_min_writers_multi_endpoint: AtomicU8::new( + me_adaptive_floor_min_writers_multi_endpoint, + ), + me_adaptive_floor_recover_grace_secs: AtomicU64::new( + me_adaptive_floor_recover_grace_secs, + ), + me_adaptive_floor_writers_per_core_total: AtomicU32::new( + me_adaptive_floor_writers_per_core_total as u32, + ), + me_adaptive_floor_cpu_cores_override: AtomicU32::new( + me_adaptive_floor_cpu_cores_override as u32, + ), + me_adaptive_floor_max_extra_writers_single_per_core: AtomicU32::new( + me_adaptive_floor_max_extra_writers_single_per_core as u32, + ), + me_adaptive_floor_max_extra_writers_multi_per_core: AtomicU32::new( + 
me_adaptive_floor_max_extra_writers_multi_per_core as u32, + ), + me_adaptive_floor_max_active_writers_per_core: AtomicU32::new( + me_adaptive_floor_max_active_writers_per_core as u32, + ), + me_adaptive_floor_max_warm_writers_per_core: AtomicU32::new( + me_adaptive_floor_max_warm_writers_per_core as u32, + ), + me_adaptive_floor_max_active_writers_global: AtomicU32::new( + me_adaptive_floor_max_active_writers_global, + ), + me_adaptive_floor_max_warm_writers_global: AtomicU32::new( + me_adaptive_floor_max_warm_writers_global, + ), + me_adaptive_floor_cpu_cores_detected: AtomicU32::new(1), + me_adaptive_floor_cpu_cores_effective: AtomicU32::new(1), + me_adaptive_floor_global_cap_raw: AtomicU64::new(0), + me_adaptive_floor_global_cap_effective: AtomicU64::new(0), + me_adaptive_floor_target_writers_total: AtomicU64::new(0), + me_adaptive_floor_active_cap_configured: AtomicU64::new(0), + me_adaptive_floor_active_cap_effective: AtomicU64::new(0), + me_adaptive_floor_warm_cap_configured: AtomicU64::new(0), + me_adaptive_floor_warm_cap_effective: AtomicU64::new(0), + me_adaptive_floor_active_writers_current: AtomicU64::new(0), + me_adaptive_floor_warm_writers_current: AtomicU64::new(0), + }), + writer_selection_policy: Arc::new(WriterSelectionPolicyCore { + secret_atomic_snapshot: AtomicBool::new(me_secret_atomic_snapshot), + me_deterministic_writer_sort: AtomicBool::new(me_deterministic_writer_sort), + me_writer_pick_mode: AtomicU8::new(me_writer_pick_mode.as_u8()), + me_writer_pick_sample_size: AtomicU8::new(me_writer_pick_sample_size.clamp(2, 4)), + }), + transport_policy: Arc::new(TransportPolicyCore { + me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()), + me_reader_route_data_wait_ms: Arc::new(AtomicU64::new( + me_reader_route_data_wait_ms, + )), + }), decision, upstream, rng, @@ -432,185 +805,26 @@ impl MePool { }, secret: proxy_secret, })), - nat_ip_cfg: nat_ip, - nat_ip_detected: Arc::new(RwLock::new(None)), - nat_probe, - nat_stun, - nat_stun_servers, - nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())), - nat_probe_concurrency: nat_probe_concurrency.max(1), - detected_ipv6, - nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), - nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), - stun_backoff_until: Arc::new(RwLock::new(None)), - me_one_retry, - me_one_timeout: Duration::from_millis(me_one_timeout_ms), stats, - me_keepalive_enabled, - me_keepalive_interval: Duration::from_secs(me_keepalive_interval_secs), - me_keepalive_jitter: Duration::from_secs(me_keepalive_jitter_secs), - me_keepalive_payload_random, - rpc_proxy_req_every_secs: AtomicU64::new(rpc_proxy_req_every_secs), - writer_cmd_channel_capacity: me_writer_cmd_channel_capacity.max(1), - me_warmup_stagger_enabled, - me_warmup_step_delay: Duration::from_millis(me_warmup_step_delay_ms), - me_warmup_step_jitter: Duration::from_millis(me_warmup_step_jitter_ms), - me_reconnect_max_concurrent_per_dc, - me_reconnect_backoff_base: Duration::from_millis(me_reconnect_backoff_base_ms), - me_reconnect_backoff_cap: Duration::from_millis(me_reconnect_backoff_cap_ms), - me_reconnect_fast_retry_count, - me_single_endpoint_shadow_writers: AtomicU8::new(me_single_endpoint_shadow_writers), - me_single_endpoint_outage_mode_enabled: AtomicBool::new( - me_single_endpoint_outage_mode_enabled, - ), - me_single_endpoint_outage_disable_quarantine: AtomicBool::new( - me_single_endpoint_outage_disable_quarantine, - ), - me_single_endpoint_outage_backoff_min_ms: AtomicU64::new( - me_single_endpoint_outage_backoff_min_ms, 
- ), - me_single_endpoint_outage_backoff_max_ms: AtomicU64::new( - me_single_endpoint_outage_backoff_max_ms, - ), - me_single_endpoint_shadow_rotate_every_secs: AtomicU64::new( - me_single_endpoint_shadow_rotate_every_secs, - ), - me_floor_mode: AtomicU8::new(me_floor_mode.as_u8()), - me_adaptive_floor_idle_secs: AtomicU64::new(me_adaptive_floor_idle_secs), - me_adaptive_floor_min_writers_single_endpoint: AtomicU8::new( - me_adaptive_floor_min_writers_single_endpoint, - ), - me_adaptive_floor_min_writers_multi_endpoint: AtomicU8::new( - me_adaptive_floor_min_writers_multi_endpoint, - ), - me_adaptive_floor_recover_grace_secs: AtomicU64::new( - me_adaptive_floor_recover_grace_secs, - ), - me_adaptive_floor_writers_per_core_total: AtomicU32::new( - me_adaptive_floor_writers_per_core_total as u32, - ), - me_adaptive_floor_cpu_cores_override: AtomicU32::new( - me_adaptive_floor_cpu_cores_override as u32, - ), - me_adaptive_floor_max_extra_writers_single_per_core: AtomicU32::new( - me_adaptive_floor_max_extra_writers_single_per_core as u32, - ), - me_adaptive_floor_max_extra_writers_multi_per_core: AtomicU32::new( - me_adaptive_floor_max_extra_writers_multi_per_core as u32, - ), - me_adaptive_floor_max_active_writers_per_core: AtomicU32::new( - me_adaptive_floor_max_active_writers_per_core as u32, - ), - me_adaptive_floor_max_warm_writers_per_core: AtomicU32::new( - me_adaptive_floor_max_warm_writers_per_core as u32, - ), - me_adaptive_floor_max_active_writers_global: AtomicU32::new( - me_adaptive_floor_max_active_writers_global, - ), - me_adaptive_floor_max_warm_writers_global: AtomicU32::new( - me_adaptive_floor_max_warm_writers_global, - ), - me_adaptive_floor_cpu_cores_detected: AtomicU32::new(1), - me_adaptive_floor_cpu_cores_effective: AtomicU32::new(1), - me_adaptive_floor_global_cap_raw: AtomicU64::new(0), - me_adaptive_floor_global_cap_effective: AtomicU64::new(0), - me_adaptive_floor_target_writers_total: AtomicU64::new(0), - me_adaptive_floor_active_cap_configured: AtomicU64::new(0), - me_adaptive_floor_active_cap_effective: AtomicU64::new(0), - me_adaptive_floor_warm_cap_configured: AtomicU64::new(0), - me_adaptive_floor_warm_cap_effective: AtomicU64::new(0), - me_adaptive_floor_active_writers_current: AtomicU64::new(0), - me_adaptive_floor_warm_writers_current: AtomicU64::new(0), pool_size: 2, proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), endpoint_dc_map: Arc::new(RwLock::new(endpoint_dc_map)), default_dc: AtomicI32::new(default_dc.unwrap_or(2)), next_writer_id: AtomicU64::new(1), - ping_tracker: Arc::new(Mutex::new(HashMap::new())), - ping_tracker_last_cleanup_epoch_ms: AtomicU64::new(0), rtt_stats: Arc::new(Mutex::new(HashMap::new())), - nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), - nat_reflection_singleflight_v4: Arc::new(Mutex::new(())), - nat_reflection_singleflight_v6: Arc::new(Mutex::new(())), - writer_available: Arc::new(Notify::new()), refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), draining_active_runtime: AtomicU64::new(0), - generation: AtomicU64::new(1), - active_generation: AtomicU64::new(1), - warm_generation: AtomicU64::new(0), - pending_hardswap_generation: AtomicU64::new(0), - pending_hardswap_started_at_epoch_secs: AtomicU64::new(0), - pending_hardswap_map_hash: AtomicU64::new(0), - hardswap: AtomicBool::new(hardswap), endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())), 
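
`me_writer_pick_sample_size` is clamped to `2..=4` both here and in the runtime updater, which suggests a sampled "best of k" writer pick. The following is a guess at that shape, flagged as such: the load metric (`queued`), struct fields, and the deterministic stand-in for random sampling are all assumptions, not the crate's actual selection logic:

```rust
struct Writer {
    id: u64,
    queued: usize, // hypothetical load metric
}

fn pick_best_of_k(writers: &[Writer], sample_size: u8, seed: usize) -> Option<&Writer> {
    if writers.is_empty() {
        return None;
    }
    let k = usize::from(sample_size.clamp(2, 4)).min(writers.len());
    // Deterministic stand-in for random sampling: stride from a seed.
    (0..k)
        .map(|i| &writers[(seed + i * 7) % writers.len()])
        .min_by_key(|w| w.queued)
}

fn main() {
    let writers = vec![
        Writer { id: 1, queued: 9 },
        Writer { id: 2, queued: 3 },
        Writer { id: 3, queued: 5 },
    ];
    let picked = pick_best_of_k(&writers, 2, 0).unwrap();
    println!("picked writer {} (queued={})", picked.id, picked.queued);
}
```
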
kdf_material_fingerprint: Arc::new(RwLock::new(HashMap::new())), - me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), - me_instadrain: AtomicBool::new(me_instadrain), - me_pool_drain_threshold: AtomicU64::new(me_pool_drain_threshold), - me_pool_drain_soft_evict_enabled: AtomicBool::new(me_pool_drain_soft_evict_enabled), - me_pool_drain_soft_evict_grace_secs: AtomicU64::new( - me_pool_drain_soft_evict_grace_secs, - ), - me_pool_drain_soft_evict_per_writer: AtomicU8::new( - me_pool_drain_soft_evict_per_writer.max(1), - ), - me_pool_drain_soft_evict_budget_per_core: AtomicU32::new( - me_pool_drain_soft_evict_budget_per_core.max(1) as u32, - ), - me_pool_drain_soft_evict_cooldown_ms: AtomicU64::new( - me_pool_drain_soft_evict_cooldown_ms.max(1), - ), - me_pool_force_close_secs: AtomicU64::new(Self::normalize_force_close_secs( - me_pool_force_close_secs, - )), - me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( - me_pool_min_fresh_ratio, - )), - me_hardswap_warmup_delay_min_ms: AtomicU64::new(me_hardswap_warmup_delay_min_ms), - me_hardswap_warmup_delay_max_ms: AtomicU64::new(me_hardswap_warmup_delay_max_ms), - me_hardswap_warmup_extra_passes: AtomicU32::new(me_hardswap_warmup_extra_passes as u32), - me_hardswap_warmup_pass_backoff_base_ms: AtomicU64::new( - me_hardswap_warmup_pass_backoff_base_ms, - ), - me_bind_stale_mode: AtomicU8::new(me_bind_stale_mode.as_u8()), - me_bind_stale_ttl_secs: AtomicU64::new(me_bind_stale_ttl_secs), - secret_atomic_snapshot: AtomicBool::new(me_secret_atomic_snapshot), - me_deterministic_writer_sort: AtomicBool::new(me_deterministic_writer_sort), - me_writer_pick_mode: AtomicU8::new(me_writer_pick_mode.as_u8()), - me_writer_pick_sample_size: AtomicU8::new(me_writer_pick_sample_size.clamp(2, 4)), - me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()), - me_reader_route_data_wait_ms: Arc::new(AtomicU64::new(me_reader_route_data_wait_ms)), - me_route_no_writer_mode: AtomicU8::new(me_route_no_writer_mode.as_u8()), - me_route_no_writer_wait: Duration::from_millis(me_route_no_writer_wait_ms), - me_route_inline_recovery_attempts, - me_route_inline_recovery_wait: Duration::from_millis(me_route_inline_recovery_wait_ms), - me_health_interval_ms_unhealthy: AtomicU64::new(me_health_interval_ms_unhealthy.max(1)), - me_health_interval_ms_healthy: AtomicU64::new(me_health_interval_ms_healthy.max(1)), - me_warn_rate_limit_ms: AtomicU64::new(me_warn_rate_limit_ms.max(1)), - me_family_v4_runtime_state: AtomicU8::new(MeFamilyRuntimeState::Healthy as u8), - me_family_v6_runtime_state: AtomicU8::new(MeFamilyRuntimeState::Healthy as u8), - me_family_v4_state_since_epoch_secs: AtomicU64::new(Self::now_epoch_secs()), - me_family_v6_state_since_epoch_secs: AtomicU64::new(Self::now_epoch_secs()), - me_family_v4_suppressed_until_epoch_secs: AtomicU64::new(0), - me_family_v6_suppressed_until_epoch_secs: AtomicU64::new(0), - me_family_v4_fail_streak: AtomicU32::new(0), - me_family_v6_fail_streak: AtomicU32::new(0), - me_family_v4_recover_success_streak: AtomicU32::new(0), - me_family_v6_recover_success_streak: AtomicU32::new(0), - me_last_drain_gate_route_quorum_ok: AtomicBool::new(false), - me_last_drain_gate_redundancy_ok: AtomicBool::new(false), - me_last_drain_gate_block_reason: AtomicU8::new(MeDrainGateReason::Open as u8), - me_last_drain_gate_updated_at_epoch_secs: AtomicU64::new(Self::now_epoch_secs()), runtime_ready: AtomicBool::new(false), - preferred_endpoints_by_dc: Arc::new(RwLock::new(preferred_endpoints_by_dc)), }) } pub fn 
current_generation(&self) -> u64 { - self.active_generation.load(Ordering::Relaxed) + self.reinit.active_generation.load(Ordering::Relaxed) } pub fn set_runtime_ready(&self, ready: bool) { @@ -621,7 +835,19 @@ impl MePool { self.runtime_ready.load(Ordering::Relaxed) } - #[allow(dead_code)] + pub(super) fn now_epoch_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 + } + + pub(super) fn notify_writer_epoch(&self) { + self.writer_epoch.send_modify(|epoch| { + *epoch = epoch.wrapping_add(1); + }); + } + pub(super) fn set_family_runtime_state( &self, family: IpFamily, @@ -631,82 +857,81 @@ impl MePool { fail_streak: u32, recover_success_streak: u32, ) { + let snapshot = Arc::new(FamilyHealthSnapshot::new( + state, + state_since_epoch_secs, + suppressed_until_epoch_secs, + fail_streak, + recover_success_streak, + )); match family { - IpFamily::V4 => { - self.me_family_v4_runtime_state - .store(state as u8, Ordering::Relaxed); - self.me_family_v4_state_since_epoch_secs - .store(state_since_epoch_secs, Ordering::Relaxed); - self.me_family_v4_suppressed_until_epoch_secs - .store(suppressed_until_epoch_secs, Ordering::Relaxed); - self.me_family_v4_fail_streak - .store(fail_streak, Ordering::Relaxed); - self.me_family_v4_recover_success_streak - .store(recover_success_streak, Ordering::Relaxed); - } - IpFamily::V6 => { - self.me_family_v6_runtime_state - .store(state as u8, Ordering::Relaxed); - self.me_family_v6_state_since_epoch_secs - .store(state_since_epoch_secs, Ordering::Relaxed); - self.me_family_v6_suppressed_until_epoch_secs - .store(suppressed_until_epoch_secs, Ordering::Relaxed); - self.me_family_v6_fail_streak - .store(fail_streak, Ordering::Relaxed); - self.me_family_v6_recover_success_streak - .store(recover_success_streak, Ordering::Relaxed); - } + IpFamily::V4 => self.health_runtime.family_health_v4.store(snapshot), + IpFamily::V6 => self.health_runtime.family_health_v6.store(snapshot), } } pub(crate) fn family_runtime_state(&self, family: IpFamily) -> MeFamilyRuntimeState { match family { - IpFamily::V4 => MeFamilyRuntimeState::from_u8( - self.me_family_v4_runtime_state.load(Ordering::Relaxed), - ), - IpFamily::V6 => MeFamilyRuntimeState::from_u8( - self.me_family_v6_runtime_state.load(Ordering::Relaxed), - ), + IpFamily::V4 => self.health_runtime.family_health_v4.load().state, + IpFamily::V6 => self.health_runtime.family_health_v6.load().state, } } pub(crate) fn family_runtime_state_since_epoch_secs(&self, family: IpFamily) -> u64 { match family { - IpFamily::V4 => self - .me_family_v4_state_since_epoch_secs - .load(Ordering::Relaxed), - IpFamily::V6 => self - .me_family_v6_state_since_epoch_secs - .load(Ordering::Relaxed), + IpFamily::V4 => { + self.health_runtime + .family_health_v4 + .load() + .state_since_epoch_secs + } + IpFamily::V6 => { + self.health_runtime + .family_health_v6 + .load() + .state_since_epoch_secs + } } } pub(crate) fn family_suppressed_until_epoch_secs(&self, family: IpFamily) -> u64 { match family { - IpFamily::V4 => self - .me_family_v4_suppressed_until_epoch_secs - .load(Ordering::Relaxed), - IpFamily::V6 => self - .me_family_v6_suppressed_until_epoch_secs - .load(Ordering::Relaxed), + IpFamily::V4 => { + self.health_runtime + .family_health_v4 + .load() + .suppressed_until_epoch_secs + } + IpFamily::V6 => { + self.health_runtime + .family_health_v6 + .load() + .suppressed_until_epoch_secs + } } } pub(crate) fn family_fail_streak(&self, family: IpFamily) -> u32 { match family { - IpFamily::V4 
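
`notify_writer_epoch` bumps a wrapping counter through `watch::Sender::send_modify`, so any task holding a `Receiver` can await `changed()` to learn that the writer set was republished, replacing the old `Notify`-based wakeup. A minimal end-to-end sketch:

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(0u64);

    let waiter = tokio::spawn(async move {
        // Resolves as soon as the value changes relative to what this
        // receiver last observed, even if the send already happened.
        rx.changed().await.expect("sender dropped");
        let epoch = *rx.borrow();
        println!("observed writer epoch {epoch}");
    });

    // Equivalent of notify_writer_epoch(): increment in place, waking waiters.
    tx.send_modify(|epoch| *epoch = epoch.wrapping_add(1));

    waiter.await.unwrap();
}
```

Unlike `Notify`, a watch channel carries state, so a late subscriber still sees the current epoch instead of missing a wakeup.
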
=> self.me_family_v4_fail_streak.load(Ordering::Relaxed), - IpFamily::V6 => self.me_family_v6_fail_streak.load(Ordering::Relaxed), + IpFamily::V4 => self.health_runtime.family_health_v4.load().fail_streak, + IpFamily::V6 => self.health_runtime.family_health_v6.load().fail_streak, } } pub(crate) fn family_recover_success_streak(&self, family: IpFamily) -> u32 { match family { - IpFamily::V4 => self - .me_family_v4_recover_success_streak - .load(Ordering::Relaxed), - IpFamily::V6 => self - .me_family_v6_recover_success_streak - .load(Ordering::Relaxed), + IpFamily::V4 => { + self.health_runtime + .family_health_v4 + .load() + .recover_success_streak + } + IpFamily::V6 => { + self.health_runtime + .family_health_v6 + .load() + .recover_success_streak + } } } @@ -737,32 +962,43 @@ impl MePool { block_reason: MeDrainGateReason, updated_at_epoch_secs: u64, ) { - self.me_last_drain_gate_route_quorum_ok + self.drain_runtime + .me_last_drain_gate_route_quorum_ok .store(route_quorum_ok, Ordering::Relaxed); - self.me_last_drain_gate_redundancy_ok + self.drain_runtime + .me_last_drain_gate_redundancy_ok .store(redundancy_ok, Ordering::Relaxed); - self.me_last_drain_gate_block_reason + self.drain_runtime + .me_last_drain_gate_block_reason .store(block_reason as u8, Ordering::Relaxed); - self.me_last_drain_gate_updated_at_epoch_secs + self.drain_runtime + .me_last_drain_gate_updated_at_epoch_secs .store(updated_at_epoch_secs, Ordering::Relaxed); } pub(crate) fn last_drain_gate_route_quorum_ok(&self) -> bool { - self.me_last_drain_gate_route_quorum_ok + self.drain_runtime + .me_last_drain_gate_route_quorum_ok .load(Ordering::Relaxed) } pub(crate) fn last_drain_gate_redundancy_ok(&self) -> bool { - self.me_last_drain_gate_redundancy_ok + self.drain_runtime + .me_last_drain_gate_redundancy_ok .load(Ordering::Relaxed) } pub(crate) fn last_drain_gate_block_reason(&self) -> MeDrainGateReason { - MeDrainGateReason::from_u8(self.me_last_drain_gate_block_reason.load(Ordering::Relaxed)) + MeDrainGateReason::from_u8( + self.drain_runtime + .me_last_drain_gate_block_reason + .load(Ordering::Relaxed), + ) } pub(crate) fn last_drain_gate_updated_at_epoch_secs(&self) -> u64 { - self.me_last_drain_gate_updated_at_epoch_secs + self.drain_runtime + .me_last_drain_gate_updated_at_epoch_secs .load(Ordering::Relaxed) } @@ -812,112 +1048,162 @@ impl MePool { me_health_interval_ms_healthy: u64, me_warn_rate_limit_ms: u64, ) { - self.hardswap.store(hardswap, Ordering::Relaxed); - self.me_pool_drain_ttl_secs + self.reinit.hardswap.store(hardswap, Ordering::Relaxed); + self.drain_runtime + .me_pool_drain_ttl_secs .store(drain_ttl_secs, Ordering::Relaxed); - self.me_instadrain.store(instadrain, Ordering::Relaxed); - self.me_pool_drain_threshold + self.drain_runtime + .me_instadrain + .store(instadrain, Ordering::Relaxed); + self.drain_runtime + .me_pool_drain_threshold .store(pool_drain_threshold, Ordering::Relaxed); - self.me_pool_drain_soft_evict_enabled + // Runtime soft-evict knobs are updated lock-free to keep control-plane + // writes non-blocking; readers observe a short eventual-consistency + // window by design. 
+ self.drain_runtime + .me_pool_drain_soft_evict_enabled .store(pool_drain_soft_evict_enabled, Ordering::Relaxed); - self.me_pool_drain_soft_evict_grace_secs + self.drain_runtime + .me_pool_drain_soft_evict_grace_secs .store(pool_drain_soft_evict_grace_secs, Ordering::Relaxed); - self.me_pool_drain_soft_evict_per_writer + self.drain_runtime + .me_pool_drain_soft_evict_per_writer .store(pool_drain_soft_evict_per_writer.max(1), Ordering::Relaxed); - self.me_pool_drain_soft_evict_budget_per_core.store( - pool_drain_soft_evict_budget_per_core.max(1) as u32, - Ordering::Relaxed, - ); - self.me_pool_drain_soft_evict_cooldown_ms + self.drain_runtime + .me_pool_drain_soft_evict_budget_per_core + .store( + pool_drain_soft_evict_budget_per_core.max(1) as u32, + Ordering::Relaxed, + ); + self.drain_runtime + .me_pool_drain_soft_evict_cooldown_ms .store(pool_drain_soft_evict_cooldown_ms.max(1), Ordering::Relaxed); - self.me_pool_force_close_secs.store( + self.drain_runtime.me_pool_force_close_secs.store( Self::normalize_force_close_secs(force_close_secs), Ordering::Relaxed, ); - self.me_pool_min_fresh_ratio_permille + self.drain_runtime + .me_pool_min_fresh_ratio_permille .store(Self::ratio_to_permille(min_fresh_ratio), Ordering::Relaxed); - self.me_hardswap_warmup_delay_min_ms + self.reinit + .me_hardswap_warmup_delay_min_ms .store(hardswap_warmup_delay_min_ms, Ordering::Relaxed); - self.me_hardswap_warmup_delay_max_ms + self.reinit + .me_hardswap_warmup_delay_max_ms .store(hardswap_warmup_delay_max_ms, Ordering::Relaxed); - self.me_hardswap_warmup_extra_passes + self.reinit + .me_hardswap_warmup_extra_passes .store(hardswap_warmup_extra_passes as u32, Ordering::Relaxed); - self.me_hardswap_warmup_pass_backoff_base_ms + self.reinit + .me_hardswap_warmup_pass_backoff_base_ms .store(hardswap_warmup_pass_backoff_base_ms, Ordering::Relaxed); - self.me_bind_stale_mode + self.binding_policy + .me_bind_stale_mode .store(bind_stale_mode.as_u8(), Ordering::Relaxed); - self.me_bind_stale_ttl_secs + self.binding_policy + .me_bind_stale_ttl_secs .store(bind_stale_ttl_secs, Ordering::Relaxed); - self.secret_atomic_snapshot + self.writer_selection_policy + .secret_atomic_snapshot .store(secret_atomic_snapshot, Ordering::Relaxed); - self.me_deterministic_writer_sort + self.writer_selection_policy + .me_deterministic_writer_sort .store(deterministic_writer_sort, Ordering::Relaxed); let previous_writer_pick_mode = self.writer_pick_mode(); - self.me_writer_pick_mode + self.writer_selection_policy + .me_writer_pick_mode .store(writer_pick_mode.as_u8(), Ordering::Relaxed); - self.me_writer_pick_sample_size + self.writer_selection_policy + .me_writer_pick_sample_size .store(writer_pick_sample_size.clamp(2, 4), Ordering::Relaxed); if previous_writer_pick_mode != writer_pick_mode { self.stats.increment_me_writer_pick_mode_switch_total(); } - self.me_single_endpoint_shadow_writers + self.single_endpoint_runtime + .me_single_endpoint_shadow_writers .store(single_endpoint_shadow_writers, Ordering::Relaxed); - self.me_single_endpoint_outage_mode_enabled + self.single_endpoint_runtime + .me_single_endpoint_outage_mode_enabled .store(single_endpoint_outage_mode_enabled, Ordering::Relaxed); - self.me_single_endpoint_outage_disable_quarantine + self.single_endpoint_runtime + .me_single_endpoint_outage_disable_quarantine .store(single_endpoint_outage_disable_quarantine, Ordering::Relaxed); - self.me_single_endpoint_outage_backoff_min_ms + self.single_endpoint_runtime + .me_single_endpoint_outage_backoff_min_ms 
.store(single_endpoint_outage_backoff_min_ms, Ordering::Relaxed); - self.me_single_endpoint_outage_backoff_max_ms + self.single_endpoint_runtime + .me_single_endpoint_outage_backoff_max_ms .store(single_endpoint_outage_backoff_max_ms, Ordering::Relaxed); - self.me_single_endpoint_shadow_rotate_every_secs + self.single_endpoint_runtime + .me_single_endpoint_shadow_rotate_every_secs .store(single_endpoint_shadow_rotate_every_secs, Ordering::Relaxed); let previous_floor_mode = self.floor_mode(); - self.me_floor_mode + self.floor_runtime + .me_floor_mode .store(floor_mode.as_u8(), Ordering::Relaxed); - self.me_adaptive_floor_idle_secs + self.floor_runtime + .me_adaptive_floor_idle_secs .store(adaptive_floor_idle_secs, Ordering::Relaxed); - self.me_adaptive_floor_min_writers_single_endpoint.store( - adaptive_floor_min_writers_single_endpoint, - Ordering::Relaxed, - ); - self.me_adaptive_floor_min_writers_multi_endpoint + self.floor_runtime + .me_adaptive_floor_min_writers_single_endpoint + .store( + adaptive_floor_min_writers_single_endpoint, + Ordering::Relaxed, + ); + self.floor_runtime + .me_adaptive_floor_min_writers_multi_endpoint .store(adaptive_floor_min_writers_multi_endpoint, Ordering::Relaxed); - self.me_adaptive_floor_recover_grace_secs + self.floor_runtime + .me_adaptive_floor_recover_grace_secs .store(adaptive_floor_recover_grace_secs, Ordering::Relaxed); - self.me_adaptive_floor_writers_per_core_total.store( - adaptive_floor_writers_per_core_total as u32, - Ordering::Relaxed, - ); - self.me_adaptive_floor_cpu_cores_override + self.floor_runtime + .me_adaptive_floor_writers_per_core_total + .store( + adaptive_floor_writers_per_core_total as u32, + Ordering::Relaxed, + ); + self.floor_runtime + .me_adaptive_floor_cpu_cores_override .store(adaptive_floor_cpu_cores_override as u32, Ordering::Relaxed); - self.me_adaptive_floor_max_extra_writers_single_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_single_per_core .store( adaptive_floor_max_extra_writers_single_per_core as u32, Ordering::Relaxed, ); - self.me_adaptive_floor_max_extra_writers_multi_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_multi_per_core .store( adaptive_floor_max_extra_writers_multi_per_core as u32, Ordering::Relaxed, ); - self.me_adaptive_floor_max_active_writers_per_core.store( - adaptive_floor_max_active_writers_per_core as u32, - Ordering::Relaxed, - ); - self.me_adaptive_floor_max_warm_writers_per_core.store( - adaptive_floor_max_warm_writers_per_core as u32, - Ordering::Relaxed, - ); - self.me_adaptive_floor_max_active_writers_global + self.floor_runtime + .me_adaptive_floor_max_active_writers_per_core + .store( + adaptive_floor_max_active_writers_per_core as u32, + Ordering::Relaxed, + ); + self.floor_runtime + .me_adaptive_floor_max_warm_writers_per_core + .store( + adaptive_floor_max_warm_writers_per_core as u32, + Ordering::Relaxed, + ); + self.floor_runtime + .me_adaptive_floor_max_active_writers_global .store(adaptive_floor_max_active_writers_global, Ordering::Relaxed); - self.me_adaptive_floor_max_warm_writers_global + self.floor_runtime + .me_adaptive_floor_max_warm_writers_global .store(adaptive_floor_max_warm_writers_global, Ordering::Relaxed); - self.me_health_interval_ms_unhealthy + self.health_runtime + .me_health_interval_ms_unhealthy .store(me_health_interval_ms_unhealthy.max(1), Ordering::Relaxed); - self.me_health_interval_ms_healthy + self.health_runtime + .me_health_interval_ms_healthy .store(me_health_interval_ms_healthy.max(1), 
Ordering::Relaxed);
-        self.me_warn_rate_limit_ms
+        self.health_runtime
+            .me_warn_rate_limit_ms
             .store(me_warn_rate_limit_ms.max(1), Ordering::Relaxed);
         if previous_floor_mode != floor_mode {
             self.stats.increment_me_floor_mode_switch_total();
@@ -936,9 +1222,13 @@ impl MePool {
     }
 
     pub fn reset_stun_state(&self) {
-        self.nat_probe_attempts.store(0, Ordering::Relaxed);
-        self.nat_probe_disabled.store(false, Ordering::Relaxed);
-        if let Ok(mut live) = self.nat_stun_live_servers.try_write() {
+        self.nat_runtime
+            .nat_probe_attempts
+            .store(0, Ordering::Relaxed);
+        self.nat_runtime
+            .nat_probe_disabled
+            .store(false, Ordering::Relaxed);
+        if let Ok(mut live) = self.nat_runtime.nat_stun_live_servers.try_write() {
             live.clear();
         }
     }
@@ -960,9 +1250,11 @@ impl MePool {
         route_backpressure_high_watermark_pct: u8,
         reader_route_data_wait_ms: u64,
     ) {
-        self.me_socks_kdf_policy
+        self.transport_policy
+            .me_socks_kdf_policy
             .store(socks_kdf_policy.as_u8(), Ordering::Relaxed);
-        self.me_reader_route_data_wait_ms
+        self.transport_policy
+            .me_reader_route_data_wait_ms
             .store(reader_route_data_wait_ms, Ordering::Relaxed);
         self.registry.update_route_backpressure_policy(
             route_backpressure_base_timeout_ms,
@@ -972,41 +1264,52 @@ impl MePool {
     }
 
     pub(super) fn socks_kdf_policy(&self) -> MeSocksKdfPolicy {
-        MeSocksKdfPolicy::from_u8(self.me_socks_kdf_policy.load(Ordering::Relaxed))
+        MeSocksKdfPolicy::from_u8(
+            self.transport_policy
+                .me_socks_kdf_policy
+                .load(Ordering::Relaxed),
+        )
     }
 
-    pub(super) fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
+    pub(super) fn writers_arc(&self) -> Arc {
         self.writers.clone()
     }
 
     pub(super) fn force_close_timeout(&self) -> Option<Duration> {
-        let secs =
-            Self::normalize_force_close_secs(self.me_pool_force_close_secs.load(Ordering::Relaxed));
+        let secs = Self::normalize_force_close_secs(
+            self.drain_runtime
+                .me_pool_force_close_secs
+                .load(Ordering::Relaxed),
+        );
         Some(Duration::from_secs(secs))
     }
 
     #[allow(dead_code)]
     pub(super) fn drain_soft_evict_enabled(&self) -> bool {
-        self.me_pool_drain_soft_evict_enabled
+        self.drain_runtime
+            .me_pool_drain_soft_evict_enabled
             .load(Ordering::Relaxed)
     }
 
     #[allow(dead_code)]
     pub(super) fn drain_soft_evict_grace_secs(&self) -> u64 {
-        self.me_pool_drain_soft_evict_grace_secs
+        self.drain_runtime
+            .me_pool_drain_soft_evict_grace_secs
             .load(Ordering::Relaxed)
     }
 
     #[allow(dead_code)]
     pub(super) fn drain_soft_evict_per_writer(&self) -> usize {
-        self.me_pool_drain_soft_evict_per_writer
+        self.drain_runtime
+            .me_pool_drain_soft_evict_per_writer
             .load(Ordering::Relaxed)
             .max(1) as usize
     }
 
     #[allow(dead_code)]
     pub(super) fn drain_soft_evict_budget_per_core(&self) -> usize {
-        self.me_pool_drain_soft_evict_budget_per_core
+        self.drain_runtime
+            .me_pool_drain_soft_evict_budget_per_core
             .load(Ordering::Relaxed)
             .max(1) as usize
     }
@@ -1014,7 +1317,8 @@ impl MePool {
     #[allow(dead_code)]
     pub(super) fn drain_soft_evict_cooldown(&self) -> Duration {
         Duration::from_millis(
-            self.me_pool_drain_soft_evict_cooldown_ms
+            self.drain_runtime
+                .me_pool_drain_soft_evict_cooldown_ms
                 .load(Ordering::Relaxed)
                 .max(1),
         )
@@ -1078,15 +1382,24 @@ impl MePool {
     }
 
     pub(super) fn bind_stale_mode(&self) -> MeBindStaleMode {
-        MeBindStaleMode::from_u8(self.me_bind_stale_mode.load(Ordering::Relaxed))
+        MeBindStaleMode::from_u8(
+            self.binding_policy
+                .me_bind_stale_mode
+                .load(Ordering::Relaxed),
+        )
     }
 
     pub(super) fn writer_pick_mode(&self) -> MeWriterPickMode {
-        MeWriterPickMode::from_u8(self.me_writer_pick_mode.load(Ordering::Relaxed))
+        MeWriterPickMode::from_u8(
+
self.writer_selection_policy + .me_writer_pick_mode + .load(Ordering::Relaxed), + ) } pub(super) fn writer_pick_sample_size(&self) -> usize { - self.me_writer_pick_sample_size + self.writer_selection_policy + .me_writer_pick_sample_size .load(Ordering::Relaxed) .clamp(2, 4) as usize } @@ -1097,6 +1410,7 @@ impl MePool { } if endpoint_count == 1 { let shadow = self + .single_endpoint_runtime .me_single_endpoint_shadow_writers .load(Ordering::Relaxed) as usize; return (1 + shadow).max(3); @@ -1105,39 +1419,48 @@ impl MePool { } pub(super) fn floor_mode(&self) -> MeFloorMode { - MeFloorMode::from_u8(self.me_floor_mode.load(Ordering::Relaxed)) + MeFloorMode::from_u8(self.floor_runtime.me_floor_mode.load(Ordering::Relaxed)) } pub(super) fn adaptive_floor_idle_duration(&self) -> Duration { - Duration::from_secs(self.me_adaptive_floor_idle_secs.load(Ordering::Relaxed)) + Duration::from_secs( + self.floor_runtime + .me_adaptive_floor_idle_secs + .load(Ordering::Relaxed), + ) } pub(super) fn adaptive_floor_recover_grace_duration(&self) -> Duration { Duration::from_secs( - self.me_adaptive_floor_recover_grace_secs + self.floor_runtime + .me_adaptive_floor_recover_grace_secs .load(Ordering::Relaxed), ) } pub(super) fn adaptive_floor_min_writers_multi_endpoint(&self) -> usize { (self + .floor_runtime .me_adaptive_floor_min_writers_multi_endpoint .load(Ordering::Relaxed) as usize) .max(1) } pub(super) fn adaptive_floor_max_extra_single_per_core(&self) -> usize { - self.me_adaptive_floor_max_extra_writers_single_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_single_per_core .load(Ordering::Relaxed) as usize } pub(super) fn adaptive_floor_max_extra_multi_per_core(&self) -> usize { - self.me_adaptive_floor_max_extra_writers_multi_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_multi_per_core .load(Ordering::Relaxed) as usize } pub(super) fn adaptive_floor_max_active_writers_per_core(&self) -> usize { (self + .floor_runtime .me_adaptive_floor_max_active_writers_per_core .load(Ordering::Relaxed) as usize) .max(1) @@ -1145,6 +1468,7 @@ impl MePool { pub(super) fn adaptive_floor_max_warm_writers_per_core(&self) -> usize { (self + .floor_runtime .me_adaptive_floor_max_warm_writers_per_core .load(Ordering::Relaxed) as usize) .max(1) @@ -1152,6 +1476,7 @@ impl MePool { pub(super) fn adaptive_floor_max_active_writers_global(&self) -> usize { (self + .floor_runtime .me_adaptive_floor_max_active_writers_global .load(Ordering::Relaxed) as usize) .max(1) @@ -1159,6 +1484,7 @@ impl MePool { pub(super) fn adaptive_floor_max_warm_writers_global(&self) -> usize { (self + .floor_runtime .me_adaptive_floor_max_warm_writers_global .load(Ordering::Relaxed) as usize) .max(1) @@ -1174,6 +1500,7 @@ impl MePool { pub(super) fn adaptive_floor_effective_cpu_cores(&self) -> usize { let detected = self.adaptive_floor_detected_cpu_cores(); let override_cores = self + .floor_runtime .me_adaptive_floor_cpu_cores_override .load(Ordering::Relaxed) as usize; let effective = if override_cores == 0 { @@ -1181,9 +1508,11 @@ impl MePool { } else { override_cores.max(1) }; - self.me_adaptive_floor_cpu_cores_detected + self.floor_runtime + .me_adaptive_floor_cpu_cores_detected .store(detected as u32, Ordering::Relaxed); - self.me_adaptive_floor_cpu_cores_effective + self.floor_runtime + .me_adaptive_floor_cpu_cores_effective .store(effective as u32, Ordering::Relaxed); self.stats .set_me_floor_cpu_cores_detected_gauge(detected as u64); @@ -1215,7 +1544,8 @@ impl MePool { 
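The hunk body below derives a configured writer cap by clamping a per-core budget against the global and per-contour limits. A sketch of that arithmetic (function and parameter names are assumptions, not the pool's API):

    // configured = (per_core * cores), bounded by the global and contour
    // budgets, and never below 1 so the pool always keeps one writer.
    fn configured_cap(
        per_core: usize,
        cores: usize,
        global_max: usize,
        per_contour_budget: usize,
    ) -> usize {
        per_core
            .saturating_mul(cores)
            .min(global_max)
            .min(per_contour_budget)
            .max(1)
    }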
            .min(self.adaptive_floor_max_active_writers_global())
             .min(per_contour_budget)
             .max(1);
-        self.me_adaptive_floor_active_cap_configured
+        self.floor_runtime
+            .me_adaptive_floor_active_cap_configured
             .store(configured as u64, Ordering::Relaxed);
         self.stats
             .set_me_floor_active_cap_configured_gauge(configured as u64);
@@ -1230,7 +1560,8 @@ impl MePool {
             .min(self.adaptive_floor_max_warm_writers_global())
             .min(per_contour_budget)
             .max(1);
-        self.me_adaptive_floor_warm_cap_configured
+        self.floor_runtime
+            .me_adaptive_floor_warm_cap_configured
             .store(configured as u64, Ordering::Relaxed);
         self.stats
             .set_me_floor_warm_cap_configured_gauge(configured as u64);
@@ -1247,23 +1578,32 @@ impl MePool {
         active_writers_current: usize,
         warm_writers_current: usize,
     ) {
-        self.me_adaptive_floor_global_cap_raw
+        self.floor_runtime
+            .me_adaptive_floor_global_cap_raw
             .store(active_cap_configured as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_global_cap_effective
+        self.floor_runtime
+            .me_adaptive_floor_global_cap_effective
             .store(active_cap_effective as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_target_writers_total
+        self.floor_runtime
+            .me_adaptive_floor_target_writers_total
             .store(target_writers_total as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_active_cap_configured
+        self.floor_runtime
+            .me_adaptive_floor_active_cap_configured
             .store(active_cap_configured as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_active_cap_effective
+        self.floor_runtime
+            .me_adaptive_floor_active_cap_effective
             .store(active_cap_effective as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_warm_cap_configured
+        self.floor_runtime
+            .me_adaptive_floor_warm_cap_configured
             .store(warm_cap_configured as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_warm_cap_effective
+        self.floor_runtime
+            .me_adaptive_floor_warm_cap_effective
             .store(warm_cap_effective as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_active_writers_current
+        self.floor_runtime
+            .me_adaptive_floor_active_writers_current
             .store(active_writers_current as u64, Ordering::Relaxed);
-        self.me_adaptive_floor_warm_writers_current
+        self.floor_runtime
+            .me_adaptive_floor_warm_writers_current
             .store(warm_writers_current as u64, Ordering::Relaxed);
         self.stats
             .set_me_floor_global_cap_raw_gauge(active_cap_configured as u64);
@@ -1352,11 +1692,13 @@ impl MePool {
         }
         let min_writers = if endpoint_count == 1 {
             (self
+                .floor_runtime
                 .me_adaptive_floor_min_writers_single_endpoint
                 .load(Ordering::Relaxed) as usize)
                 .max(1)
         } else {
             (self
+                .floor_runtime
                 .me_adaptive_floor_min_writers_multi_endpoint
                 .load(Ordering::Relaxed) as usize)
                 .max(1)
@@ -1365,20 +1707,24 @@ impl MePool {
     }
 
     pub(super) fn single_endpoint_outage_mode_enabled(&self) -> bool {
-        self.me_single_endpoint_outage_mode_enabled
+        self.single_endpoint_runtime
+            .me_single_endpoint_outage_mode_enabled
             .load(Ordering::Relaxed)
     }
 
     pub(super) fn single_endpoint_outage_disable_quarantine(&self) -> bool {
-        self.me_single_endpoint_outage_disable_quarantine
+        self.single_endpoint_runtime
+            .me_single_endpoint_outage_disable_quarantine
             .load(Ordering::Relaxed)
     }
 
     pub(super) fn single_endpoint_outage_backoff_bounds_ms(&self) -> (u64, u64) {
         let min_ms = self
+            .single_endpoint_runtime
             .me_single_endpoint_outage_backoff_min_ms
             .load(Ordering::Relaxed);
         let max_ms = self
+            .single_endpoint_runtime
             .me_single_endpoint_outage_backoff_max_ms
             .load(Ordering::Relaxed);
         if min_ms <= max_ms {
@@ -1390,6 +1736,7 @@ impl MePool {
 
     pub(super) fn single_endpoint_shadow_rotate_interval(&self) -> Option<Duration> {
         let secs = self
+            .single_endpoint_runtime
             .me_single_endpoint_shadow_rotate_every_secs
             .load(Ordering::Relaxed);
         if secs == 0 {
@@ -1573,17 +1920,34 @@ impl MePool {
         let rebuilt = Self::build_endpoint_dc_map_from_maps(&map_v4, &map_v6);
         let preferred = Self::build_preferred_endpoints_by_dc(&self.decision, &map_v4, &map_v6);
         *self.endpoint_dc_map.write().await = rebuilt;
-        *self.preferred_endpoints_by_dc.write().await = preferred;
+        self.preferred_endpoints_by_dc.store(Arc::new(preferred));
+        let configured_endpoints = self
+            .endpoint_dc_map
+            .read()
+            .await
+            .keys()
+            .copied()
+            .collect::<HashSet<_>>();
+        {
+            let mut quarantine = self.endpoint_quarantine.lock().await;
+            let now = Instant::now();
+            quarantine.retain(|addr, expiry| *expiry > now && configured_endpoints.contains(addr));
+        }
+        {
+            let mut kdf_fp = self.kdf_material_fingerprint.write().await;
+            kdf_fp.retain(|addr, _| configured_endpoints.contains(addr));
+        }
     }
 
     pub(super) async fn preferred_endpoints_for_dc(&self, dc: i32) -> Vec<SocketAddr> {
-        let guard = self.preferred_endpoints_by_dc.read().await;
+        let guard = self.preferred_endpoints_by_dc.load();
         guard.get(&dc).cloned().unwrap_or_default()
     }
 
     pub(super) fn health_interval_unhealthy(&self) -> Duration {
         Duration::from_millis(
-            self.me_health_interval_ms_unhealthy
+            self.health_runtime
+                .me_health_interval_ms_unhealthy
                 .load(Ordering::Relaxed)
                 .max(1),
         )
@@ -1591,13 +1955,19 @@ impl MePool {
 
     pub(super) fn health_interval_healthy(&self) -> Duration {
         Duration::from_millis(
-            self.me_health_interval_ms_healthy
+            self.health_runtime
+                .me_health_interval_ms_healthy
                 .load(Ordering::Relaxed)
                 .max(1),
         )
     }
 
     pub(super) fn warn_rate_limit_duration(&self) -> Duration {
-        Duration::from_millis(self.me_warn_rate_limit_ms.load(Ordering::Relaxed).max(1))
+        Duration::from_millis(
+            self.health_runtime
+                .me_warn_rate_limit_ms
+                .load(Ordering::Relaxed)
+                .max(1),
+        )
     }
 }
diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs
index 486fad0..6e29918 100644
--- a/src/transport/middle_proxy/pool_config.rs
+++ b/src/transport/middle_proxy/pool_config.rs
@@ -72,7 +72,7 @@ impl MePool {
         }
         if changed {
             self.rebuild_endpoint_dc_map().await;
-            self.writer_available.notify_waiters();
+            self.notify_writer_epoch();
         }
         if changed {
             SnapshotApplyOutcome::AppliedChanged
@@ -112,7 +112,7 @@ impl MePool {
 
     pub async fn reconnect_all(self: &Arc<Self>) {
         let ws = self.writers.read().await.clone();
-        for w in ws {
+        for w in ws.iter() {
             if let Ok(()) = self
                 .connect_one_for_dc(w.addr, w.writer_dc, self.rng.as_ref())
                 .await
diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs
index 2e3bc1d..3f7cad7 100644
--- a/src/transport/middle_proxy/pool_init.rs
+++ b/src/transport/middle_proxy/pool_init.rs
@@ -14,7 +14,10 @@ use super::pool::MePool;
 impl MePool {
     pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc) -> Result<()> {
         let family_order = self.family_order();
-        let connect_concurrency = self.me_reconnect_max_concurrent_per_dc.max(1) as usize;
+        let connect_concurrency = self
+            .reconnect_runtime
+            .me_reconnect_max_concurrent_per_dc
+            .max(1) as usize;
         let ks = self.key_selector().await;
         info!(
             me_servers = self.proxy_map_v4.read().await.len(),
@@ -250,10 +253,12 @@ impl MePool {
             return false;
         }
 
-        if self.me_warmup_stagger_enabled {
-            let jitter =
-                rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
-            let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
+        if self.reconnect_runtime.me_warmup_stagger_enabled {
+            let jitter = rand::rng().random_range(
+                0..=self.reconnect_runtime.me_warmup_step_jitter.as_millis() as u64,
+            );
+            let delay_ms =
+                self.reconnect_runtime.me_warmup_step_delay.as_millis() as u64 + jitter;
             tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
         }
     }
diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs
index f382fd4..be2d9df 100644
--- a/src/transport/middle_proxy/pool_nat.rs
+++ b/src/transport/middle_proxy/pool_nat.rs
@@ -42,10 +42,10 @@ pub async fn detect_public_ip() -> Option {
 impl MePool {
     fn configured_stun_servers(&self) -> Vec<String> {
-        if !self.nat_stun_servers.is_empty() {
-            return self.nat_stun_servers.clone();
+        if !self.nat_runtime.nat_stun_servers.is_empty() {
+            return self.nat_runtime.nat_stun_servers.clone();
         }
-        if let Some(s) = &self.nat_stun
+        if let Some(s) = &self.nat_runtime.nat_stun
             && !s.trim().is_empty()
         {
             return vec![s.clone()];
@@ -64,7 +64,7 @@ impl MePool {
         let mut next_idx = 0usize;
         let mut live_servers = Vec::new();
         let mut best_by_ip: HashMap = HashMap::new();
-        let concurrency = self.nat_probe_concurrency.max(1);
+        let concurrency = self.nat_runtime.nat_probe_concurrency.max(1);
 
         while next_idx < servers.len() || !join_set.is_empty() {
             while next_idx < servers.len() && join_set.len() < concurrency {
@@ -137,9 +137,13 @@ impl MePool {
     }
 
     pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
-        let nat_ip = self
-            .nat_ip_cfg
-            .or_else(|| self.nat_ip_detected.try_read().ok().and_then(|g| *g));
+        let nat_ip = self.nat_runtime.nat_ip_cfg.or_else(|| {
+            self.nat_runtime
+                .nat_ip_detected
+                .try_read()
+                .ok()
+                .and_then(|g| *g)
+        });
 
         let Some(nat_ip) = nat_ip else {
             return ip;
@@ -163,7 +167,7 @@ impl MePool {
         addr: std::net::SocketAddr,
         reflected: Option<std::net::SocketAddr>,
     ) -> std::net::SocketAddr {
-        let ip = if let Some(nat_ip) = self.nat_ip_cfg {
+        let ip = if let Some(nat_ip) = self.nat_runtime.nat_ip_cfg {
             match (addr.ip(), nat_ip) {
                 (IpAddr::V4(_), IpAddr::V4(dst)) => IpAddr::V4(dst),
                 (IpAddr::V6(_), IpAddr::V6(dst)) => IpAddr::V6(dst),
@@ -185,22 +189,22 @@ impl MePool {
     }
 
     pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option<IpAddr> {
-        if self.nat_ip_cfg.is_some() {
-            return self.nat_ip_cfg;
+        if self.nat_runtime.nat_ip_cfg.is_some() {
+            return self.nat_runtime.nat_ip_cfg;
         }
 
         if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
             return None;
         }
 
-        if let Some(ip) = *self.nat_ip_detected.read().await {
+        if let Some(ip) = *self.nat_runtime.nat_ip_detected.read().await {
             return Some(ip);
         }
 
         match fetch_public_ipv4_with_retry().await {
             Ok(Some(ip)) => {
                 {
-                    let mut guard = self.nat_ip_detected.write().await;
+                    let mut guard = self.nat_runtime.nat_ip_detected.write().await;
                     *guard = Some(IpAddr::V4(ip));
                 }
                 info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
@@ -231,10 +235,10 @@ impl MePool {
         }
 
         // Backoff window
         if use_shared_cache
-            && let Some(until) = *self.stun_backoff_until.read().await
+            && let Some(until) = *self.nat_runtime.stun_backoff_until.read().await
             && Instant::now() < until
         {
-            if let Ok(cache) = self.nat_reflection_cache.try_lock() {
+            if let Ok(cache) = self.nat_runtime.nat_reflection_cache.try_lock() {
                 let slot = match family {
                     IpFamily::V4 => cache.v4,
                     IpFamily::V6 => cache.v6,
@@ -244,7 +248,8 @@ impl MePool {
             return None;
         }
 
-        if use_shared_cache && let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
+        if use_shared_cache && let Ok(mut cache) =
self.nat_runtime.nat_reflection_cache.try_lock() + { let slot = match family { IpFamily::V4 => &mut cache.v4, IpFamily::V6 => &mut cache.v6, @@ -258,18 +263,18 @@ impl MePool { let _singleflight_guard = if use_shared_cache { Some(match family { - IpFamily::V4 => self.nat_reflection_singleflight_v4.lock().await, - IpFamily::V6 => self.nat_reflection_singleflight_v6.lock().await, + IpFamily::V4 => self.nat_runtime.nat_reflection_singleflight_v4.lock().await, + IpFamily::V6 => self.nat_runtime.nat_reflection_singleflight_v6.lock().await, }) } else { None }; if use_shared_cache - && let Some(until) = *self.stun_backoff_until.read().await + && let Some(until) = *self.nat_runtime.stun_backoff_until.read().await && Instant::now() < until { - if let Ok(cache) = self.nat_reflection_cache.try_lock() { + if let Ok(cache) = self.nat_runtime.nat_reflection_cache.try_lock() { let slot = match family { IpFamily::V4 => cache.v4, IpFamily::V6 => cache.v6, @@ -279,7 +284,8 @@ impl MePool { return None; } - if use_shared_cache && let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + if use_shared_cache && let Ok(mut cache) = self.nat_runtime.nat_reflection_cache.try_lock() + { let slot = match family { IpFamily::V4 => &mut cache.v4, IpFamily::V6 => &mut cache.v6, @@ -292,13 +298,14 @@ impl MePool { } let attempt = if use_shared_cache { - self.nat_probe_attempts + self.nat_runtime + .nat_probe_attempts .fetch_add(1, std::sync::atomic::Ordering::Relaxed) } else { 0 }; let configured_servers = self.configured_stun_servers(); - let live_snapshot = self.nat_stun_live_servers.read().await.clone(); + let live_snapshot = self.nat_runtime.nat_stun_live_servers.read().await.clone(); let primary_servers = if live_snapshot.is_empty() { configured_servers.clone() } else { @@ -322,14 +329,15 @@ impl MePool { let live_server_count = live_servers.len(); if !live_servers.is_empty() { - *self.nat_stun_live_servers.write().await = live_servers; + *self.nat_runtime.nat_stun_live_servers.write().await = live_servers; } else { - self.nat_stun_live_servers.write().await.clear(); + self.nat_runtime.nat_stun_live_servers.write().await.clear(); } if let Some(reflected_addr) = selected_reflected { if use_shared_cache { - self.nat_probe_attempts + self.nat_runtime + .nat_probe_attempts .store(0, std::sync::atomic::Ordering::Relaxed); } info!( @@ -338,7 +346,9 @@ impl MePool { "STUN-Quorum reached, IP: {}", reflected_addr.ip() ); - if use_shared_cache && let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + if use_shared_cache + && let Ok(mut cache) = self.nat_runtime.nat_reflection_cache.try_lock() + { let slot = match family { IpFamily::V4 => &mut cache.v4, IpFamily::V6 => &mut cache.v6, @@ -350,7 +360,7 @@ impl MePool { if use_shared_cache { let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6))); - *self.stun_backoff_until.write().await = Some(Instant::now() + backoff); + *self.nat_runtime.stun_backoff_until.write().await = Some(Instant::now() + backoff); } None } diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index d93bcfe..69d8aa0 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -13,13 +13,40 @@ use super::pool::{MePool, RefillDcKey, RefillEndpointKey, WriterContour}; const ME_FLAP_UPTIME_THRESHOLD_SECS: u64 = 20; const ME_FLAP_QUARANTINE_SECS: u64 = 25; +const ME_FLAP_MIN_UPTIME_MILLIS: u64 = 500; +const ME_REFILL_TOTAL_ATTEMPT_CAP: u32 = 20; impl MePool { + pub(super) async fn 
 sweep_endpoint_quarantine(&self) {
+        let configured = self
+            .endpoint_dc_map
+            .read()
+            .await
+            .keys()
+            .copied()
+            .collect::<HashSet<_>>();
+        let now = Instant::now();
+        let mut guard = self.endpoint_quarantine.lock().await;
+        guard.retain(|addr, expiry| *expiry > now && configured.contains(addr));
+    }
+
     pub(super) async fn maybe_quarantine_flapping_endpoint(
         &self,
         addr: SocketAddr,
         uptime: Duration,
+        reason: &'static str,
     ) {
+        if uptime < Duration::from_millis(ME_FLAP_MIN_UPTIME_MILLIS) {
+            debug!(
+                %addr,
+                reason,
+                uptime_ms = uptime.as_millis(),
+                min_uptime_ms = ME_FLAP_MIN_UPTIME_MILLIS,
+                "Skipping flap quarantine for ultra-short writer lifetime"
+            );
+            return;
+        }
+
         if uptime > Duration::from_secs(ME_FLAP_UPTIME_THRESHOLD_SECS) {
             return;
         }
@@ -31,6 +58,7 @@ impl MePool {
         self.stats.increment_me_endpoint_quarantine_total();
         warn!(
             %addr,
+            reason,
             uptime_ms = uptime.as_millis(),
             quarantine_secs = ME_FLAP_QUARANTINE_SECS,
             "ME endpoint temporarily quarantined due to rapid writer flap"
@@ -205,11 +233,16 @@ impl MePool {
     }
 
     async fn refill_writer_after_loss(self: &Arc<Self>, addr: SocketAddr, writer_dc: i32) -> bool {
-        let fast_retries = self.me_reconnect_fast_retry_count.max(1);
+        let fast_retries = self.reconnect_runtime.me_reconnect_fast_retry_count.max(1);
+        let mut total_attempts = 0u32;
         let same_endpoint_quarantined = self.is_endpoint_quarantined(addr).await;
 
         if !same_endpoint_quarantined {
             for attempt in 0..fast_retries {
+                if total_attempts >= ME_REFILL_TOTAL_ATTEMPT_CAP {
+                    break;
+                }
+                total_attempts = total_attempts.saturating_add(1);
                 self.stats.increment_me_reconnect_attempt();
                 match self
                     .connect_one_for_dc(addr, writer_dc, self.rng.as_ref())
@@ -250,6 +283,10 @@ impl MePool {
         }
 
         for attempt in 0..fast_retries {
+            if total_attempts >= ME_REFILL_TOTAL_ATTEMPT_CAP {
+                break;
+            }
+            total_attempts = total_attempts.saturating_add(1);
             self.stats.increment_me_reconnect_attempt();
             if self
                 .connect_endpoints_round_robin(writer_dc, &dc_endpoints, self.rng.as_ref())
diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs
index 663007b..db6411c 100644
--- a/src/transport/middle_proxy/pool_reinit.rs
+++ b/src/transport/middle_proxy/pool_reinit.rs
@@ -37,16 +37,23 @@ impl MePool {
     }
 
     fn clear_pending_hardswap_state(&self) {
-        self.pending_hardswap_generation.store(0, Ordering::Relaxed);
-        self.pending_hardswap_started_at_epoch_secs
+        self.reinit
+            .pending_hardswap_generation
             .store(0, Ordering::Relaxed);
-        self.pending_hardswap_map_hash.store(0, Ordering::Relaxed);
-        self.warm_generation.store(0, Ordering::Relaxed);
+        self.reinit
+            .pending_hardswap_started_at_epoch_secs
+            .store(0, Ordering::Relaxed);
+        self.reinit
+            .pending_hardswap_map_hash
+            .store(0, Ordering::Relaxed);
+        self.reinit.warm_generation.store(0, Ordering::Relaxed);
     }
 
     async fn promote_warm_generation_to_active(&self, generation: u64) {
-        self.active_generation.store(generation, Ordering::Relaxed);
-        self.warm_generation.store(0, Ordering::Relaxed);
+        self.reinit
+            .active_generation
+            .store(generation, Ordering::Relaxed);
+        self.reinit.warm_generation.store(0, Ordering::Relaxed);
 
         let ws = self.writers.read().await;
         for writer in ws.iter() {
@@ -184,8 +191,14 @@ impl MePool {
     }
 
     fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
-        let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
-        let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
+        let min_ms = self
+            .reinit
+            .me_hardswap_warmup_delay_min_ms
+            .load(Ordering::Relaxed);
+        let max_ms = self
+            .reinit
+            .me_hardswap_warmup_delay_max_ms
+            .load(Ordering::Relaxed);
         let (min_ms, max_ms) = if min_ms <= max_ms {
             (min_ms, max_ms)
         } else {
@@ -199,9 +212,11 @@ impl MePool {
 
     fn hardswap_warmup_backoff_ms(&self, pass_idx: usize) -> u64 {
         let base_ms = self
+            .reinit
             .me_hardswap_warmup_pass_backoff_base_ms
             .load(Ordering::Relaxed);
-        let cap_ms = (self.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms);
+        let cap_ms =
+            (self.reconnect_runtime.me_reconnect_backoff_cap.as_millis() as u64).max(base_ms);
         let shift = (pass_idx as u32).min(20);
         let scaled = base_ms.saturating_mul(1u64 << shift);
         let core = scaled.min(cap_ms);
@@ -244,6 +259,7 @@ impl MePool {
         desired_by_dc: &HashMap<i32, Vec<SocketAddr>>,
     ) {
         let extra_passes = self
+            .reinit
             .me_hardswap_warmup_extra_passes
             .load(Ordering::Relaxed)
             .min(10) as usize;
@@ -369,13 +385,20 @@ impl MePool {
         let desired_map_hash = Self::desired_map_hash(&desired_by_dc);
         let previous_generation = self.current_generation();
-        let hardswap = self.hardswap.load(Ordering::Relaxed);
+        let hardswap = self.reinit.hardswap.load(Ordering::Relaxed);
         let generation = if hardswap {
-            let pending_generation = self.pending_hardswap_generation.load(Ordering::Relaxed);
+            let pending_generation = self
+                .reinit
+                .pending_hardswap_generation
+                .load(Ordering::Relaxed);
             let pending_started_at = self
+                .reinit
                 .pending_hardswap_started_at_epoch_secs
                 .load(Ordering::Relaxed);
-            let pending_map_hash = self.pending_hardswap_map_hash.load(Ordering::Relaxed);
+            let pending_map_hash = self
+                .reinit
+                .pending_hardswap_map_hash
+                .load(Ordering::Relaxed);
             let pending_age_secs = now_epoch_secs.saturating_sub(pending_started_at);
             let pending_ttl_expired =
                 pending_started_at > 0 && pending_age_secs > ME_HARDSWAP_PENDING_TTL_SECS;
@@ -405,24 +428,30 @@ impl MePool {
                         "ME hardswap pending generation expired by TTL; starting fresh generation"
                     );
                 }
-                let next_generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
-                self.pending_hardswap_generation
+                let next_generation = self.reinit.generation.fetch_add(1, Ordering::Relaxed) + 1;
+                self.reinit
+                    .pending_hardswap_generation
                     .store(next_generation, Ordering::Relaxed);
-                self.pending_hardswap_started_at_epoch_secs
+                self.reinit
+                    .pending_hardswap_started_at_epoch_secs
                     .store(now_epoch_secs, Ordering::Relaxed);
-                self.pending_hardswap_map_hash
+                self.reinit
+                    .pending_hardswap_map_hash
                     .store(desired_map_hash, Ordering::Relaxed);
-                self.warm_generation
+                self.reinit
+                    .warm_generation
                     .store(next_generation, Ordering::Relaxed);
                 next_generation
             }
         } else {
             self.clear_pending_hardswap_state();
-            self.generation.fetch_add(1, Ordering::Relaxed) + 1
+            self.reinit.generation.fetch_add(1, Ordering::Relaxed) + 1
         };
 
         if hardswap {
-            self.warm_generation.store(generation, Ordering::Relaxed);
+            self.reinit
+                .warm_generation
+                .store(generation, Ordering::Relaxed);
             self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc)
                 .await;
         } else {
@@ -436,7 +465,8 @@ impl MePool {
             .map(|w| (w.writer_dc, w.addr))
             .collect();
         let min_ratio = Self::permille_to_ratio(
-            self.me_pool_min_fresh_ratio_permille
+            self.drain_runtime
+                .me_pool_min_fresh_ratio_permille
                 .load(Ordering::Relaxed),
         );
         let (coverage_ratio, missing_dc) =
diff --git a/src/transport/middle_proxy/pool_runtime_api.rs b/src/transport/middle_proxy/pool_runtime_api.rs
index 7c15216..539f397 100644
--- a/src/transport/middle_proxy/pool_runtime_api.rs
+++ b/src/transport/middle_proxy/pool_runtime_api.rs
@@ -94,9 +94,9 @@ impl MePool {
 
     pub(crate) async fn
api_nat_stun_snapshot(&self) -> MeApiNatStunSnapshot { let now = Instant::now(); - let mut configured_servers = if !self.nat_stun_servers.is_empty() { - self.nat_stun_servers.clone() - } else if let Some(stun) = &self.nat_stun { + let mut configured_servers = if !self.nat_runtime.nat_stun_servers.is_empty() { + self.nat_runtime.nat_stun_servers.clone() + } else if let Some(stun) = &self.nat_runtime.nat_stun { if stun.trim().is_empty() { Vec::new() } else { @@ -108,11 +108,11 @@ impl MePool { configured_servers.sort(); configured_servers.dedup(); - let mut live_servers = self.nat_stun_live_servers.read().await.clone(); + let mut live_servers = self.nat_runtime.nat_stun_live_servers.read().await.clone(); live_servers.sort(); live_servers.dedup(); - let reflection = self.nat_reflection_cache.lock().await; + let reflection = self.nat_runtime.nat_reflection_cache.lock().await; let reflection_v4 = reflection.v4.map(|(ts, addr)| MeApiNatReflectionSnapshot { addr, age_secs: now.saturating_duration_since(ts).as_secs(), @@ -123,17 +123,19 @@ impl MePool { }); drop(reflection); - let backoff_until = *self.stun_backoff_until.read().await; + let backoff_until = *self.nat_runtime.stun_backoff_until.read().await; let stun_backoff_remaining_ms = backoff_until.and_then(|until| { (until > now).then_some(until.duration_since(now).as_millis() as u64) }); MeApiNatStunSnapshot { - nat_probe_enabled: self.nat_probe, + nat_probe_enabled: self.nat_runtime.nat_probe, nat_probe_disabled_runtime: self + .nat_runtime .nat_probe_disabled .load(std::sync::atomic::Ordering::Relaxed), nat_probe_attempts: self + .nat_runtime .nat_probe_attempts .load(std::sync::atomic::Ordering::Relaxed), configured_servers, diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 918ccd4..ae9038b 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -160,7 +160,7 @@ impl MePool { let writers = self.writers.read().await.clone(); let mut live_writers_by_dc = HashMap::::new(); - for writer in writers { + for writer in writers.iter() { if writer.draining.load(Ordering::Relaxed) { continue; } @@ -197,7 +197,7 @@ impl MePool { let writers = self.writers.read().await.clone(); let mut live_writers_by_dc = HashMap::::new(); - for writer in writers { + for writer in writers.iter() { if writer.draining.load(Ordering::Relaxed) { continue; } @@ -224,7 +224,10 @@ impl MePool { pub(crate) async fn api_status_snapshot(&self) -> MeApiStatusSnapshot { let now_epoch_secs = Self::now_epoch_secs(); let active_generation = self.current_generation(); - let drain_ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed); + let drain_ttl_secs = self + .drain_runtime + .me_pool_drain_ttl_secs + .load(Ordering::Relaxed); let mut endpoints_by_dc = BTreeMap::>::new(); if self.decision.ipv4_me { @@ -255,7 +258,7 @@ impl MePool { let mut dc_rtt_agg = HashMap::::new(); let mut writer_rows = Vec::::with_capacity(writers.len()); - for writer in writers { + for writer in writers.iter() { let endpoint = writer.addr; let dc = i16::try_from(writer.writer_dc).ok(); let draining = writer.draining.load(Ordering::Relaxed); @@ -293,9 +296,7 @@ impl MePool { WriterContour::Draining => "draining", }; - if !draining - && let Some(dc_idx) = dc - { + if !draining && let Some(dc_idx) = dc { *live_writers_by_dc_endpoint .entry((dc_idx, endpoint)) .or_insert(0) += 1; @@ -338,6 +339,7 @@ impl MePool { let mut fresh_alive_writers = 0usize; let floor_mode = self.floor_mode(); let 
adaptive_cpu_cores = (self + .floor_runtime .me_adaptive_floor_cpu_cores_effective .load(Ordering::Relaxed) as usize) .max(1); @@ -352,22 +354,26 @@ impl MePool { self.required_writers_for_dc_with_floor_mode(endpoint_count, false); let floor_min = if endpoint_count <= 1 { (self + .floor_runtime .me_adaptive_floor_min_writers_single_endpoint .load(Ordering::Relaxed) as usize) .max(1) .min(base_required.max(1)) } else { (self + .floor_runtime .me_adaptive_floor_min_writers_multi_endpoint .load(Ordering::Relaxed) as usize) .max(1) .min(base_required.max(1)) }; let extra_per_core = if endpoint_count <= 1 { - self.me_adaptive_floor_max_extra_writers_single_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_single_per_core .load(Ordering::Relaxed) as usize } else { - self.me_adaptive_floor_max_extra_writers_multi_per_core + self.floor_runtime + .me_adaptive_floor_max_extra_writers_multi_per_core .load(Ordering::Relaxed) as usize }; let floor_max = @@ -438,6 +444,7 @@ impl MePool { let now = Instant::now(); let now_epoch_secs = Self::now_epoch_secs(); let pending_started_at = self + .reinit .pending_hardswap_started_at_epoch_secs .load(Ordering::Relaxed); let pending_hardswap_age_secs = @@ -479,119 +486,175 @@ impl MePool { } MeApiRuntimeSnapshot { - active_generation: self.active_generation.load(Ordering::Relaxed), - warm_generation: self.warm_generation.load(Ordering::Relaxed), - pending_hardswap_generation: self.pending_hardswap_generation.load(Ordering::Relaxed), + active_generation: self.reinit.active_generation.load(Ordering::Relaxed), + warm_generation: self.reinit.warm_generation.load(Ordering::Relaxed), + pending_hardswap_generation: self + .reinit + .pending_hardswap_generation + .load(Ordering::Relaxed), pending_hardswap_age_secs, - hardswap_enabled: self.hardswap.load(Ordering::Relaxed), + hardswap_enabled: self.reinit.hardswap.load(Ordering::Relaxed), floor_mode: floor_mode_label(self.floor_mode()), - adaptive_floor_idle_secs: self.me_adaptive_floor_idle_secs.load(Ordering::Relaxed), + adaptive_floor_idle_secs: self + .floor_runtime + .me_adaptive_floor_idle_secs + .load(Ordering::Relaxed), adaptive_floor_min_writers_single_endpoint: self + .floor_runtime .me_adaptive_floor_min_writers_single_endpoint .load(Ordering::Relaxed), adaptive_floor_min_writers_multi_endpoint: self + .floor_runtime .me_adaptive_floor_min_writers_multi_endpoint .load(Ordering::Relaxed), adaptive_floor_recover_grace_secs: self + .floor_runtime .me_adaptive_floor_recover_grace_secs .load(Ordering::Relaxed), adaptive_floor_writers_per_core_total: self + .floor_runtime .me_adaptive_floor_writers_per_core_total .load(Ordering::Relaxed) as u16, adaptive_floor_cpu_cores_override: self + .floor_runtime .me_adaptive_floor_cpu_cores_override .load(Ordering::Relaxed) as u16, adaptive_floor_max_extra_writers_single_per_core: self + .floor_runtime .me_adaptive_floor_max_extra_writers_single_per_core .load(Ordering::Relaxed) as u16, adaptive_floor_max_extra_writers_multi_per_core: self + .floor_runtime .me_adaptive_floor_max_extra_writers_multi_per_core .load(Ordering::Relaxed) as u16, adaptive_floor_max_active_writers_per_core: self + .floor_runtime .me_adaptive_floor_max_active_writers_per_core .load(Ordering::Relaxed) as u16, adaptive_floor_max_warm_writers_per_core: self + .floor_runtime .me_adaptive_floor_max_warm_writers_per_core .load(Ordering::Relaxed) as u16, adaptive_floor_max_active_writers_global: self + .floor_runtime .me_adaptive_floor_max_active_writers_global .load(Ordering::Relaxed), 
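The me_pool_min_fresh_ratio field a few lines below is decoded with Self::permille_to_ratio: the ratio lives in a plain integer atomic as thousandths, like the pool's other knobs. A sketch of the likely encoding (the exact rounding and clamping are assumptions):

    // Store a fraction in [0.0, 1.0] as permille so it fits an atomic integer.
    fn ratio_to_permille(ratio: f64) -> u32 {
        (ratio.clamp(0.0, 1.0) * 1000.0).round() as u32
    }

    fn permille_to_ratio(permille: u32) -> f64 {
        f64::from(permille.min(1000)) / 1000.0
    }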
adaptive_floor_max_warm_writers_global: self + .floor_runtime .me_adaptive_floor_max_warm_writers_global .load(Ordering::Relaxed), adaptive_floor_cpu_cores_detected: self + .floor_runtime .me_adaptive_floor_cpu_cores_detected .load(Ordering::Relaxed), adaptive_floor_cpu_cores_effective: self + .floor_runtime .me_adaptive_floor_cpu_cores_effective .load(Ordering::Relaxed), adaptive_floor_global_cap_raw: self + .floor_runtime .me_adaptive_floor_global_cap_raw .load(Ordering::Relaxed), adaptive_floor_global_cap_effective: self + .floor_runtime .me_adaptive_floor_global_cap_effective .load(Ordering::Relaxed), adaptive_floor_target_writers_total: self + .floor_runtime .me_adaptive_floor_target_writers_total .load(Ordering::Relaxed), adaptive_floor_active_cap_configured: self + .floor_runtime .me_adaptive_floor_active_cap_configured .load(Ordering::Relaxed), adaptive_floor_active_cap_effective: self + .floor_runtime .me_adaptive_floor_active_cap_effective .load(Ordering::Relaxed), adaptive_floor_warm_cap_configured: self + .floor_runtime .me_adaptive_floor_warm_cap_configured .load(Ordering::Relaxed), adaptive_floor_warm_cap_effective: self + .floor_runtime .me_adaptive_floor_warm_cap_effective .load(Ordering::Relaxed), adaptive_floor_active_writers_current: self + .floor_runtime .me_adaptive_floor_active_writers_current .load(Ordering::Relaxed), adaptive_floor_warm_writers_current: self + .floor_runtime .me_adaptive_floor_warm_writers_current .load(Ordering::Relaxed), - me_keepalive_enabled: self.me_keepalive_enabled, - me_keepalive_interval_secs: self.me_keepalive_interval.as_secs(), - me_keepalive_jitter_secs: self.me_keepalive_jitter.as_secs(), - me_keepalive_payload_random: self.me_keepalive_payload_random, - rpc_proxy_req_every_secs: self.rpc_proxy_req_every_secs.load(Ordering::Relaxed), - me_reconnect_max_concurrent_per_dc: self.me_reconnect_max_concurrent_per_dc, - me_reconnect_backoff_base_ms: self.me_reconnect_backoff_base.as_millis() as u64, - me_reconnect_backoff_cap_ms: self.me_reconnect_backoff_cap.as_millis() as u64, - me_reconnect_fast_retry_count: self.me_reconnect_fast_retry_count, - me_pool_drain_ttl_secs: self.me_pool_drain_ttl_secs.load(Ordering::Relaxed), - me_pool_force_close_secs: self.me_pool_force_close_secs.load(Ordering::Relaxed), + me_keepalive_enabled: self.writer_lifecycle.me_keepalive_enabled, + me_keepalive_interval_secs: self.writer_lifecycle.me_keepalive_interval.as_secs(), + me_keepalive_jitter_secs: self.writer_lifecycle.me_keepalive_jitter.as_secs(), + me_keepalive_payload_random: self.writer_lifecycle.me_keepalive_payload_random, + rpc_proxy_req_every_secs: self + .writer_lifecycle + .rpc_proxy_req_every_secs + .load(Ordering::Relaxed), + me_reconnect_max_concurrent_per_dc: self + .reconnect_runtime + .me_reconnect_max_concurrent_per_dc, + me_reconnect_backoff_base_ms: self + .reconnect_runtime + .me_reconnect_backoff_base + .as_millis() as u64, + me_reconnect_backoff_cap_ms: self.reconnect_runtime.me_reconnect_backoff_cap.as_millis() + as u64, + me_reconnect_fast_retry_count: self.reconnect_runtime.me_reconnect_fast_retry_count, + me_pool_drain_ttl_secs: self + .drain_runtime + .me_pool_drain_ttl_secs + .load(Ordering::Relaxed), + me_pool_force_close_secs: self + .drain_runtime + .me_pool_force_close_secs + .load(Ordering::Relaxed), me_pool_min_fresh_ratio: Self::permille_to_ratio( - self.me_pool_min_fresh_ratio_permille + self.drain_runtime + .me_pool_min_fresh_ratio_permille .load(Ordering::Relaxed), ), me_bind_stale_mode: 
 bind_stale_mode_label(self.bind_stale_mode()),
-            me_bind_stale_ttl_secs: self.me_bind_stale_ttl_secs.load(Ordering::Relaxed),
+            me_bind_stale_ttl_secs: self
+                .binding_policy
+                .me_bind_stale_ttl_secs
+                .load(Ordering::Relaxed),
             me_single_endpoint_shadow_writers: self
+                .single_endpoint_runtime
                 .me_single_endpoint_shadow_writers
                 .load(Ordering::Relaxed),
             me_single_endpoint_outage_mode_enabled: self
+                .single_endpoint_runtime
                 .me_single_endpoint_outage_mode_enabled
                 .load(Ordering::Relaxed),
             me_single_endpoint_outage_disable_quarantine: self
+                .single_endpoint_runtime
                 .me_single_endpoint_outage_disable_quarantine
                 .load(Ordering::Relaxed),
             me_single_endpoint_outage_backoff_min_ms: self
+                .single_endpoint_runtime
                 .me_single_endpoint_outage_backoff_min_ms
                 .load(Ordering::Relaxed),
             me_single_endpoint_outage_backoff_max_ms: self
+                .single_endpoint_runtime
                 .me_single_endpoint_outage_backoff_max_ms
                 .load(Ordering::Relaxed),
             me_single_endpoint_shadow_rotate_every_secs: self
+                .single_endpoint_runtime
                 .me_single_endpoint_shadow_rotate_every_secs
                 .load(Ordering::Relaxed),
-            me_deterministic_writer_sort: self.me_deterministic_writer_sort.load(Ordering::Relaxed),
+            me_deterministic_writer_sort: self
+                .writer_selection_policy
+                .me_deterministic_writer_sort
+                .load(Ordering::Relaxed),
             me_writer_pick_mode: writer_pick_mode_label(self.writer_pick_mode()),
             me_writer_pick_sample_size: self.writer_pick_sample_size() as u8,
             me_socks_kdf_policy: socks_kdf_policy_label(self.socks_kdf_policy()),
diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs
index 22fb909..fae68b9 100644
--- a/src/transport/middle_proxy/pool_writer.rs
+++ b/src/transport/middle_proxy/pool_writer.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::io::ErrorKind;
 use std::net::SocketAddr;
 use std::sync::Arc;
@@ -25,6 +26,7 @@ const ME_ACTIVE_PING_SECS: u64 = 25;
 const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
 const ME_IDLE_KEEPALIVE_MAX_SECS: u64 = 5;
 const ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS: u64 = 700;
+const ME_PING_TRACKER_CLEANUP_EVERY: u32 = 32;
 
 #[derive(Clone, Copy)]
 enum WriterTeardownMode {
@@ -36,6 +38,240 @@ fn is_me_peer_closed_error(error: &ProxyError) -> bool {
     matches!(error, ProxyError::Io(ioe) if ioe.kind() == ErrorKind::UnexpectedEof)
 }
 
+enum WriterLifecycleExit {
+    Reader(Result<()>),
+    Writer(Result<()>),
+    Ping,
+    Signal,
+    Cancelled,
+}
+
+async fn writer_command_loop(
+    mut rx: mpsc::Receiver<WriterCommand>,
+    mut rpc_writer: RpcWriter,
+    cancel: CancellationToken,
+) -> Result<()> {
+    loop {
+        tokio::select! {
+            cmd = rx.recv() => {
+                match cmd {
+                    Some(WriterCommand::Data(payload)) => {
+                        rpc_writer.send(&payload).await?;
+                    }
+                    Some(WriterCommand::DataAndFlush(payload)) => {
+                        rpc_writer.send_and_flush(&payload).await?;
+                    }
+                    Some(WriterCommand::Close) | None => return Ok(()),
+                }
+            }
+            _ = cancel.cancelled() => return Ok(()),
+        }
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn ping_loop(
+    pool_ping: std::sync::Weak<MePool>,
+    writer_id: u64,
+    tx_ping: mpsc::Sender<WriterCommand>,
+    ping_tracker_ping: Arc<tokio::sync::Mutex<HashMap<i64, std::time::Instant>>>,
+    stats_ping: Arc,
+    keepalive_enabled: bool,
+    keepalive_interval: Duration,
+    keepalive_jitter: Duration,
+    cancel_ping_token: CancellationToken,
+) {
+    let mut ping_id: i64 = rand::random::<i64>();
+    let mut cleanup_tick: u32 = 0;
+    let idle_interval_cap = Duration::from_secs(ME_IDLE_KEEPALIVE_MAX_SECS);
+    // Per-writer jittered start to avoid phase sync.
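The jitter rule used here and in the loop body caps jitter at half the interval, so the effective period stays within [T, 1.5T]. A standalone sketch of that computation (rand 0.9 API, as used elsewhere in this patch):

    use rand::Rng;
    use std::time::Duration;

    fn jittered_wait(interval: Duration, jitter: Duration) -> Duration {
        // Cap jitter at half the base interval, but never below 1 ms.
        let cap_ms = interval.as_millis() as u64 / 2;
        let jitter_ms = (jitter.as_millis() as u64).min(cap_ms).max(1);
        interval + Duration::from_millis(rand::rng().random_range(0..=jitter_ms))
    }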
+    let startup_jitter = if keepalive_enabled {
+        let mut interval = keepalive_interval;
+        let Some(pool) = pool_ping.upgrade() else {
+            return;
+        };
+        if pool.registry.is_writer_empty(writer_id).await {
+            interval = interval.min(idle_interval_cap);
+        }
+        let jitter_cap_ms = interval.as_millis() / 2;
+        let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
+        Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+    } else {
+        let jitter =
+            rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
+        let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
+        Duration::from_secs(wait)
+    };
+    tokio::select! {
+        _ = cancel_ping_token.cancelled() => return,
+        _ = tokio::time::sleep(startup_jitter) => {}
+    }
+    loop {
+        let wait = if keepalive_enabled {
+            let mut interval = keepalive_interval;
+            let Some(pool) = pool_ping.upgrade() else {
+                return;
+            };
+            if pool.registry.is_writer_empty(writer_id).await {
+                interval = interval.min(idle_interval_cap);
+            }
+            let jitter_cap_ms = interval.as_millis() / 2;
+            let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
+            interval
+                + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+        } else {
+            let jitter =
+                rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
+            let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
+            Duration::from_secs(secs)
+        };
+        tokio::select! {
+            _ = cancel_ping_token.cancelled() => return,
+            _ = tokio::time::sleep(wait) => {}
+        }
+        let sent_id = ping_id;
+        let mut p = Vec::with_capacity(12);
+        p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
+        p.extend_from_slice(&sent_id.to_le_bytes());
+        {
+            let mut tracker = ping_tracker_ping.lock().await;
+            cleanup_tick = cleanup_tick.wrapping_add(1);
+            if cleanup_tick.is_multiple_of(ME_PING_TRACKER_CLEANUP_EVERY) {
+                let before = tracker.len();
+                tracker.retain(|_, ts| ts.elapsed() < Duration::from_secs(120));
+                let expired = before.saturating_sub(tracker.len());
+                if expired > 0 {
+                    stats_ping.increment_me_keepalive_timeout_by(expired as u64);
+                }
+            }
+            tracker.insert(sent_id, std::time::Instant::now());
+        }
+        ping_id = ping_id.wrapping_add(1);
+        stats_ping.increment_me_keepalive_sent();
+        if tx_ping
+            .send(WriterCommand::DataAndFlush(Bytes::from(p)))
+            .await
+            .is_err()
+        {
+            stats_ping.increment_me_keepalive_failed();
+            debug!("ME ping failed, removing dead writer");
+            return;
+        }
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn rpc_proxy_req_signal_loop(
+    pool_signal: std::sync::Weak<MePool>,
+    writer_id: u64,
+    tx_signal: mpsc::Sender<WriterCommand>,
+    stats_signal: Arc,
+    cancel_signal: CancellationToken,
+    keepalive_jitter_signal: Duration,
+    rpc_proxy_req_every_secs: u64,
+) {
+    if rpc_proxy_req_every_secs == 0 {
+        // Disabled service signal loop must stay parked until writer cancellation.
+        // Returning immediately here would complete `select!` and tear down writer lifecycle.
+        cancel_signal.cancelled().await;
+        return;
+    }
+
+    let interval = Duration::from_secs(rpc_proxy_req_every_secs);
+    let startup_jitter_ms = {
+        let jitter_cap_ms = interval.as_millis() / 2;
+        let effective_jitter_ms = keepalive_jitter_signal
+            .as_millis()
+            .min(jitter_cap_ms)
+            .max(1);
+        rand::rng().random_range(0..=effective_jitter_ms as u64)
+    };
+
+    tokio::select! {
+        _ = cancel_signal.cancelled() => return,
+        _ = tokio::time::sleep(Duration::from_millis(startup_jitter_ms)) => {}
+    }
+
+    loop {
+        let wait = {
+            let jitter_cap_ms = interval.as_millis() / 2;
+            let effective_jitter_ms = keepalive_jitter_signal
+                .as_millis()
+                .min(jitter_cap_ms)
+                .max(1);
+            interval
+                + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+        };
+
+        tokio::select! {
+            _ = cancel_signal.cancelled() => return,
+            _ = tokio::time::sleep(wait) => {}
+        }
+
+        let Some(pool) = pool_signal.upgrade() else {
+            return;
+        };
+
+        let Some(meta) = pool.registry.get_last_writer_meta(writer_id).await else {
+            stats_signal.increment_me_rpc_proxy_req_signal_skipped_no_meta_total();
+            continue;
+        };
+
+        let (conn_id, mut service_rx) = pool.registry.register().await;
+        // Service RPC_PROXY_REQ signal path is intentionally route-only:
+        // do not bind synthetic conn_id into regular writer/client accounting.
+
+        let payload = build_proxy_req_payload(
+            conn_id,
+            meta.client_addr,
+            meta.our_addr,
+            &[],
+            pool.proxy_tag.as_deref(),
+            meta.proto_flags,
+        );
+
+        if tx_signal
+            .send(WriterCommand::DataAndFlush(payload))
+            .await
+            .is_err()
+        {
+            stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
+            let _ = pool.registry.unregister(conn_id).await;
+            return;
+        }
+
+        stats_signal.increment_me_rpc_proxy_req_signal_sent_total();
+
+        if matches!(
+            tokio::time::timeout(
+                Duration::from_millis(ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS),
+                service_rx.recv(),
+            )
+            .await,
+            Ok(Some(_))
+        ) {
+            stats_signal.increment_me_rpc_proxy_req_signal_response_total();
+        }
+
+        let mut close_payload = Vec::with_capacity(12);
+        close_payload.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
+        close_payload.extend_from_slice(&conn_id.to_le_bytes());
+
+        if tx_signal
+            .send(WriterCommand::DataAndFlush(Bytes::from(close_payload)))
+            .await
+            .is_err()
+        {
+            stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
+            let _ = pool.registry.unregister(conn_id).await;
+            return;
+        }
+
+        stats_signal.increment_me_rpc_proxy_req_signal_close_sent_total();
+        let _ = pool.registry.unregister(conn_id).await;
+    }
+}
+
 impl MePool {
     pub(crate) async fn prune_closed_writers(self: &Arc<Self>) {
         let closed_writer_ids: Vec<u64> = {
@@ -136,46 +372,15 @@ impl MePool {
         let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0));
         let drain_deadline_epoch_secs = Arc::new(AtomicU64::new(0));
         let allow_drain_fallback = Arc::new(AtomicBool::new(false));
-        let (tx, mut rx) = mpsc::channel::<WriterCommand>(self.writer_cmd_channel_capacity);
-        let mut rpc_writer = RpcWriter {
+        let (tx, rx) =
+            mpsc::channel::<WriterCommand>(self.writer_lifecycle.writer_cmd_channel_capacity);
+        let rpc_writer = RpcWriter {
             writer: hs.wr,
             key: hs.write_key,
             iv: hs.write_iv,
             seq_no: 0,
             crc_mode: hs.crc_mode,
         };
-        let cancel_wr = cancel.clone();
-        let cleanup_done = Arc::new(AtomicBool::new(false));
-        let cleanup_for_writer = cleanup_done.clone();
-        let pool_writer_task = Arc::downgrade(self);
-        tokio::spawn(async move {
-            loop {
-                tokio::select! {
-                    cmd = rx.recv() => {
-                        match cmd {
-                            Some(WriterCommand::Data(payload)) => {
-                                if rpc_writer.send(&payload).await.is_err() { break; }
-                            }
-                            Some(WriterCommand::DataAndFlush(payload)) => {
-                                if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
-                            }
-                            Some(WriterCommand::Close) | None => break,
-                        }
-                    }
-                    _ = cancel_wr.cancelled() => break,
-                }
-            }
-            if cleanup_for_writer
-                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                .is_ok()
-            {
-                if let Some(pool) = pool_writer_task.upgrade() {
-                    pool.remove_writer_and_close_clients(writer_id).await;
-                } else {
-                    cancel_wr.cancel();
-                }
-            }
-        });
         let writer = MeWriter {
             id: writer_id,
             addr,
@@ -193,329 +398,135 @@ impl MePool {
             drain_deadline_epoch_secs: drain_deadline_epoch_secs.clone(),
             allow_drain_fallback: allow_drain_fallback.clone(),
         };
-        self.writers.write().await.push(writer.clone());
+        self.writers
+            .update(|writers| writers.push(writer.clone()))
+            .await;
         self.registry.register_writer(writer_id, tx.clone()).await;
         self.registry.mark_writer_idle(writer_id).await;
         self.conn_count.fetch_add(1, Ordering::Relaxed);
-        self.writer_available.notify_one();
+        self.notify_writer_epoch();
 
         let reg = self.registry.clone();
         let writers_arc = self.writers_arc();
-        let ping_tracker = self.ping_tracker.clone();
+        let ping_tracker = Arc::new(tokio::sync::Mutex::new(HashMap::<i64, std::time::Instant>::new()));
         let ping_tracker_reader = ping_tracker.clone();
+        let ping_tracker_ping = ping_tracker.clone();
         let rtt_stats = self.rtt_stats.clone();
         let stats_reader = self.stats.clone();
         let stats_reader_close = self.stats.clone();
         let stats_ping = self.stats.clone();
-        let pool = Arc::downgrade(self);
-        let cancel_ping = cancel.clone();
-        let tx_ping = tx.clone();
-        let ping_tracker_ping = ping_tracker.clone();
-        let cleanup_for_reader = cleanup_done.clone();
-        let cleanup_for_ping = cleanup_done.clone();
-        let keepalive_enabled = self.me_keepalive_enabled;
-        let keepalive_interval = self.me_keepalive_interval;
-        let keepalive_jitter = self.me_keepalive_jitter;
-        let rpc_proxy_req_every_secs = self.rpc_proxy_req_every_secs.load(Ordering::Relaxed);
-        let tx_signal = tx.clone();
         let stats_signal = self.stats.clone();
-        let cancel_signal = cancel.clone();
-        let cleanup_for_signal = cleanup_done.clone();
-        let pool_signal = Arc::downgrade(self);
-        let keepalive_jitter_signal = self.me_keepalive_jitter;
-        let cancel_reader_token = cancel.clone();
-        let cancel_ping_token = cancel_ping.clone();
-        let reader_route_data_wait_ms = self.me_reader_route_data_wait_ms.clone();
-
-        tokio::spawn(async move {
-            let res = reader_loop(
-                hs.rd,
-                hs.read_key,
-                hs.read_iv,
-                hs.crc_mode,
-                reg.clone(),
-                BytesMut::new(),
-                BytesMut::new(),
-                tx.clone(),
-                ping_tracker_reader,
-                rtt_stats.clone(),
-                stats_reader,
-                writer_id,
-                degraded.clone(),
-                rtt_ema_ms_x10.clone(),
-                reader_route_data_wait_ms,
-                cancel_reader_token.clone(),
-            )
-            .await;
-            let idle_close_by_peer = if let Err(e) = res.as_ref() {
-                is_me_peer_closed_error(e) && reg.is_writer_empty(writer_id).await
-            } else {
-                false
-            };
-            if idle_close_by_peer {
-                stats_reader_close.increment_me_idle_close_by_peer_total();
-                info!(writer_id, "ME socket closed by peer on idle writer");
-            }
-            if cleanup_for_reader
-                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                .is_ok()
-            {
-                if let Some(pool) = pool.upgrade() {
-                    pool.remove_writer_and_close_clients(writer_id).await;
-                } else {
-                    // Fallback for shutdown races: make writer task exit quickly so stale
-                    // channels are observable by periodic prune.
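The removed tasks above coordinated teardown through a shared cleanup_done flag: compare_exchange guarantees exactly one of the racing tasks performs the writer cleanup. A minimal sketch of that run-once guard:

    use std::sync::atomic::{AtomicBool, Ordering};

    // The first caller flips false -> true and wins; every later caller
    // observes Err from compare_exchange and skips the cleanup work.
    fn try_claim_cleanup(done: &AtomicBool) -> bool {
        done.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok()
    }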
- tokio::spawn(async move { - let res = reader_loop( - hs.rd, - hs.read_key, - hs.read_iv, - hs.crc_mode, - reg.clone(), - BytesMut::new(), - BytesMut::new(), - tx.clone(), - ping_tracker_reader, - rtt_stats.clone(), - stats_reader, - writer_id, - degraded.clone(), - rtt_ema_ms_x10.clone(), - reader_route_data_wait_ms, - cancel_reader_token.clone(), - ) - .await; - let idle_close_by_peer = if let Err(e) = res.as_ref() { - is_me_peer_closed_error(e) && reg.is_writer_empty(writer_id).await - } else { - false - }; - if idle_close_by_peer { - stats_reader_close.increment_me_idle_close_by_peer_total(); - info!(writer_id, "ME socket closed by peer on idle writer"); - } - if cleanup_for_reader - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - if let Some(pool) = pool.upgrade() { - pool.remove_writer_and_close_clients(writer_id).await; - } else { - // Fallback for shutdown races: make writer task exit quickly so stale - // channels are observable by periodic prune. - cancel_reader_token.cancel(); - } - } - if let Err(e) = res - && !idle_close_by_peer - { - warn!(error = %e, "ME reader ended"); - } - let remaining = writers_arc.read().await.len(); - debug!(writer_id, remaining, "ME reader task finished"); - }); - + let pool_lifecycle = Arc::downgrade(self); let pool_ping = Arc::downgrade(self); + let pool_signal = Arc::downgrade(self); + let tx_reader = tx.clone(); + let tx_ping = tx.clone(); + let tx_signal = tx.clone(); + let keepalive_enabled = self.writer_lifecycle.me_keepalive_enabled; + let keepalive_interval = self.writer_lifecycle.me_keepalive_interval; + let keepalive_jitter = self.writer_lifecycle.me_keepalive_jitter; + let keepalive_jitter_signal = self.writer_lifecycle.me_keepalive_jitter; + let rpc_proxy_req_every_secs = self + .writer_lifecycle + .rpc_proxy_req_every_secs + .load(Ordering::Relaxed); + let cancel_reader = cancel.clone(); + let cancel_writer = cancel.clone(); + let cancel_ping = cancel.clone(); + let cancel_signal = cancel.clone(); + let cancel_select = cancel.clone(); + let cancel_cleanup = cancel.clone(); + let reader_route_data_wait_ms = self.transport_policy.me_reader_route_data_wait_ms.clone(); tokio::spawn(async move { - let mut ping_id: i64 = rand::random::<i64>(); - let idle_interval_cap = Duration::from_secs(ME_IDLE_KEEPALIVE_MAX_SECS); - // Per-writer jittered start to avoid phase sync. - let startup_jitter = if keepalive_enabled { - let mut interval = keepalive_interval; - if let Some(pool) = pool_ping.upgrade() { - if pool.registry.is_writer_empty(writer_id).await { - interval = interval.min(idle_interval_cap); - } - } else { - return; + // Reader MUST be the first branch in biased select! to avoid read starvation. + let exit = tokio::select! { + biased; + + reader_res = reader_loop( + hs.rd, + hs.read_key, + hs.read_iv, + hs.crc_mode, + reg.clone(), + BytesMut::new(), + BytesMut::new(), + tx_reader, + ping_tracker_reader, + rtt_stats, + stats_reader, + writer_id, + degraded, + rtt_ema_ms_x10, + reader_route_data_wait_ms, + cancel_reader, + ) => WriterLifecycleExit::Reader(reader_res), + writer_res = writer_command_loop(rx, rpc_writer, cancel_writer) => { + WriterLifecycleExit::Writer(writer_res) } - let jitter_cap_ms = interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) - } else { - let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); - let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; - Duration::from_secs(wait) + _ = ping_loop( + pool_ping, + writer_id, + tx_ping, + ping_tracker_ping, + stats_ping, + keepalive_enabled, + keepalive_interval, + keepalive_jitter, + cancel_ping, + ) => WriterLifecycleExit::Ping, + _ = rpc_proxy_req_signal_loop( + pool_signal, + writer_id, + tx_signal, + stats_signal, + cancel_signal, + keepalive_jitter_signal, + rpc_proxy_req_every_secs, + ) => WriterLifecycleExit::Signal, + _ = cancel_select.cancelled() => WriterLifecycleExit::Cancelled, };
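// --- Illustrative sketch (not part of the patch): why the comment above
// insists the reader arm comes first. With `biased;`, tokio::select! polls
// branches strictly top-down instead of in random order, so inbound frames
// are always drained before writer/ping/signal work gets a turn. A minimal,
// runnable demo assuming only the tokio crate:
#[tokio::main]
async fn main() {
    // Both branches are immediately ready; `biased` makes the outcome
    // deterministic: the first listed arm is polled first and wins.
    let picked = tokio::select! {
        biased;
        _ = async {} => "reader",
        _ = async {} => "writer",
    };
    assert_eq!(picked, "reader");
    // Without `biased`, tokio randomizes polling order on every select,
    // which is roughly what the old multi-task layout relied on.
}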
- tokio::select! { - _ = cancel_ping_token.cancelled() => return, - _ = tokio::time::sleep(startup_jitter) => {} - } - loop { - let wait = if keepalive_enabled { - let mut interval = keepalive_interval; - if let Some(pool) = pool_ping.upgrade() { - if pool.registry.is_writer_empty(writer_id).await { - interval = interval.min(idle_interval_cap); - } + + match exit { + WriterLifecycleExit::Reader(res) => { + let idle_close_by_peer = if let Err(e) = res.as_ref() { + is_me_peer_closed_error(e) && reg.is_writer_empty(writer_id).await } else { - break; + false + }; + if idle_close_by_peer { + stats_reader_close.increment_me_idle_close_by_peer_total(); + info!(writer_id, "ME socket closed by peer on idle writer"); } - let jitter_cap_ms = interval.as_millis() / 2; - let effective_jitter_ms = - keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - interval - + Duration::from_millis( - rand::rng().random_range(0..=effective_jitter_ms as u64), - ) - } else { - let jitter = rand::rng() - .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); - let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; - Duration::from_secs(secs) - }; - tokio::select! { - _ = cancel_ping_token.cancelled() => { - break; - } - _ = tokio::time::sleep(wait) => {} - } - let sent_id = ping_id; - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); - p.extend_from_slice(&sent_id.to_le_bytes()); - { - let mut tracker = ping_tracker_ping.lock().await; - let now_epoch_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as u64; - let mut run_cleanup = false; - if let Some(pool) = pool_ping.upgrade() { - let last_cleanup_ms = pool - .ping_tracker_last_cleanup_epoch_ms - .load(Ordering::Relaxed); - if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000 - && pool - .ping_tracker_last_cleanup_epoch_ms - .compare_exchange( - last_cleanup_ms, - now_epoch_ms, - Ordering::AcqRel, - Ordering::Relaxed, - ) - .is_ok() - { - run_cleanup = true; - } - } - - if run_cleanup { - let before = tracker.len(); - tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); - let expired = before.saturating_sub(tracker.len()); - if expired > 0 { - stats_ping.increment_me_keepalive_timeout_by(expired as u64); - } - } - tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); - } - ping_id = ping_id.wrapping_add(1); - stats_ping.increment_me_keepalive_sent(); - if tx_ping - .send(WriterCommand::DataAndFlush(Bytes::from(p))) - .await - .is_err() - { - stats_ping.increment_me_keepalive_failed(); - debug!("ME ping failed, removing dead writer"); - cancel_ping.cancel(); - if cleanup_for_ping - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - && let Some(pool) = pool_ping.upgrade() + if let Err(e) = res + && !idle_close_by_peer { - pool.remove_writer_and_close_clients(writer_id).await; + warn!(error = %e, "ME reader ended"); } - break; } - } - }); - - tokio::spawn(async move { - if rpc_proxy_req_every_secs == 0 { - return; - } - - let interval = Duration::from_secs(rpc_proxy_req_every_secs); - let startup_jitter_ms = { - let jitter_cap_ms = interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter_signal - .as_millis() - .min(jitter_cap_ms) - .max(1); - rand::rng().random_range(0..=effective_jitter_ms as u64) - }; - - tokio::select! 
{ - _ = cancel_signal.cancelled() => return, - _ = tokio::time::sleep(Duration::from_millis(startup_jitter_ms)) => {} - } - - loop { - let wait = { - let jitter_cap_ms = interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter_signal - .as_millis() - .min(jitter_cap_ms) - .max(1); - interval - + Duration::from_millis( - rand::rng().random_range(0..=effective_jitter_ms as u64), - ) - }; - - tokio::select! { - _ = cancel_signal.cancelled() => break, - _ = tokio::time::sleep(wait) => {} - } - - let Some(pool) = pool_signal.upgrade() else { - break; - }; - - let Some(meta) = pool.registry.get_last_writer_meta(writer_id).await else { - stats_signal.increment_me_rpc_proxy_req_signal_skipped_no_meta_total(); - continue; - }; - - let (conn_id, mut service_rx) = pool.registry.register().await; - if !pool - .registry - .bind_writer(conn_id, writer_id, meta.clone()) - .await - { - let _ = pool.registry.unregister(conn_id).await; - stats_signal.increment_me_rpc_proxy_req_signal_skipped_no_meta_total(); - continue; - } - - let payload = build_proxy_req_payload( - conn_id, - meta.client_addr, - meta.our_addr, - &[], - pool.proxy_tag.as_deref(), - meta.proto_flags, - ); - - if tx_signal - .send(WriterCommand::DataAndFlush(payload)) - .await - .is_err() - { - stats_signal.increment_me_rpc_proxy_req_signal_failed_total(); - let _ = pool.registry.unregister(conn_id).await; - cancel_signal.cancel(); - if cleanup_for_signal - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - pool.remove_writer_and_close_clients(writer_id).await; + WriterLifecycleExit::Writer(res) => { + if let Err(e) = res { + warn!(error = %e, "ME writer command loop ended"); } - break; } - - stats_signal.increment_me_rpc_proxy_req_signal_sent_total(); - - if matches!( - tokio::time::timeout( - Duration::from_millis(ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS), - service_rx.recv(), - ) - .await, - Ok(Some(_)) - ) { - stats_signal.increment_me_rpc_proxy_req_signal_response_total(); + WriterLifecycleExit::Ping => { + debug!(writer_id, "ME ping loop finished"); } - - let mut close_payload = Vec::with_capacity(12); - close_payload.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); - close_payload.extend_from_slice(&conn_id.to_le_bytes()); - - if tx_signal - .send(WriterCommand::DataAndFlush(Bytes::from(close_payload))) - .await - .is_err() - { - stats_signal.increment_me_rpc_proxy_req_signal_failed_total(); - let _ = pool.registry.unregister(conn_id).await; - cancel_signal.cancel(); - if cleanup_for_signal - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - .is_ok() - { - pool.remove_writer_and_close_clients(writer_id).await; - } - break; + WriterLifecycleExit::Signal => { + debug!(writer_id, "ME rpc_proxy_req signal loop finished"); } - - stats_signal.increment_me_rpc_proxy_req_signal_close_sent_total(); - let _ = pool.registry.unregister(conn_id).await; + WriterLifecycleExit::Cancelled => {} } + + if let Some(pool) = pool_lifecycle.upgrade() { + pool.remove_writer_and_close_clients(writer_id).await; + } else { + // Fallback for shutdown races: make lifecycle exit observable by prune. + cancel_cleanup.cancel(); + } + + let remaining = writers_arc.read().await.len(); + debug!(writer_id, remaining, "ME writer lifecycle task finished"); }); Ok(()) @@ -594,23 +605,36 @@ impl MePool { // The close command below is only a best-effort accelerator for task shutdown. // Cleanup progress must never depend on command-channel availability. 
let _ = self.registry.writer_lost(writer_id).await; - { - let mut tracker = self.ping_tracker.lock().await; - tracker.retain(|_, (_, wid)| *wid != writer_id); - } self.rtt_stats.lock().await.remove(&writer_id); if let Some(tx) = close_tx { - let _ = tx.send(WriterCommand::Close).await; + // Keep teardown critical path non-blocking: close is best-effort only. + let _ = tx.try_send(WriterCommand::Close); } if let Some(addr) = removed_addr { if let Some(uptime) = removed_uptime { - // Quarantine flapping endpoints regardless of draining state. - self.maybe_quarantine_flapping_endpoint(addr, uptime).await; + // Quarantine contract: only unexpected removals are considered endpoint flap. + if trigger_refill { + self.stats + .increment_me_endpoint_quarantine_unexpected_total(); + self.maybe_quarantine_flapping_endpoint(addr, uptime, "unexpected") + .await; + } else { + self.stats + .increment_me_endpoint_quarantine_draining_suppressed_total(); + debug!( %addr, uptime_ms = uptime.as_millis(), "Skipping endpoint quarantine for draining writer removal" ); + } } if trigger_refill && let Some(writer_dc) = removed_dc { self.trigger_immediate_refill_for_dc(addr, writer_dc); } } + if removed { + self.notify_writer_epoch(); + } removed } @@ -676,7 +700,10 @@ impl MePool { MeBindStaleMode::Never => false, MeBindStaleMode::Always => true, MeBindStaleMode::Ttl => { - let ttl_secs = self.me_bind_stale_ttl_secs.load(Ordering::Relaxed); + let ttl_secs = self + .binding_policy + .me_bind_stale_ttl_secs + .load(Ordering::Relaxed); if ttl_secs == 0 { return true; } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 4137b2b..aec55cd 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -32,10 +32,10 @@ pub(crate) async fn reader_loop( enc_leftover: BytesMut, mut dec: BytesMut, tx: mpsc::Sender<WriterCommand>, - ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>, + ping_tracker: Arc<Mutex<HashMap<i64, Instant>>>, rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, stats: Arc<Stats>, - _writer_id: u64, + writer_id: u64, degraded: Arc<AtomicBool>, writer_rtt_ema_ms_x10: Arc<AtomicU64>, reader_route_data_wait_ms: Arc<AtomicU64>, @@ -45,7 +45,7 @@ pub(crate) async fn reader_loop( let mut expected_seq: i32 = 0; loop { - let mut tmp = [0u8; 16_384]; + let mut tmp = [0u8; 65_536];
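// --- Illustrative sketch (not part of the patch): effect of growing the
// stack read buffer from 16 KiB to 64 KiB — one read() can now pull a whole
// large RPC_PROXY_ANS frame instead of several partial reads. A toy reader
// over an in-memory source, assuming only the tokio crate:
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let mut rd = std::io::Cursor::new(vec![0u8; 100_000]);
    let mut tmp = [0u8; 65_536];
    let (mut reads, mut total) = (0usize, 0usize);
    loop {
        let n = rd.read(&mut tmp).await?;
        if n == 0 { break; }
        reads += 1;
        total += n;
    }
    assert_eq!(total, 100_000);
    // With a 64 KiB buffer this takes 2 reads; 16 KiB would need 7.
    assert!(reads <= 2);
    Ok(())
}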
let n = tokio::select! { res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?, _ = cancel.cancelled() => return Ok(()), @@ -126,14 +126,10 @@ pub(crate) async fn reader_loop( let data = body.slice(12..); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); - let data_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = if data_wait_ms == 0 { - reg.route_nowait(cid, MeResponse::Data { flags, data }) .await - } else { - reg.route_with_timeout(cid, MeResponse::Data { flags, data }, data_wait_ms) .await - }; + let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); + let routed = reg + .route_with_timeout(cid, MeResponse::Data { flags, data }, route_wait_ms) + .await; if !matches!(routed, RouteResult::Routed) { match routed { RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), @@ -207,13 +203,13 @@ pub(crate) async fn reader_loop( } else if pt == RPC_PONG_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); stats.increment_me_keepalive_pong(); - if let Some((sent, wid)) = { + if let Some(sent) = { let mut guard = ping_tracker.lock().await; guard.remove(&ping_id) } { let rtt = sent.elapsed().as_secs_f64() * 1000.0; let mut stats = rtt_stats.lock().await; - let entry = stats.entry(wid).or_insert((rtt, rtt)); + let entry = stats.entry(writer_id).or_insert((rtt, rtt)); entry.1 = entry.1 * 0.8 + rtt * 0.2; if rtt < entry.0 { entry.0 = rtt; @@ -228,7 +224,7 @@ pub(crate) async fn reader_loop( Ordering::Relaxed, ); trace!( - writer_id = wid, + writer_id, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 0a95e18..ff4a68b 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -3,8 +3,9 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use dashmap::DashMap; use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::{RwLock, mpsc}; +use tokio::sync::{Mutex, mpsc}; use super::MeResponse; use super::codec::WriterCommand; @@ -50,8 +51,15 @@ pub(super) struct WriterActivitySnapshot { pub active_sessions_by_target_dc: HashMap<i32, usize>, } -struct RegistryInner { - map: HashMap<u64, mpsc::Sender<MeResponse>>, +struct RoutingTable { + map: DashMap<u64, mpsc::Sender<MeResponse>>, +} + +struct BindingState { + inner: Mutex<BindingInner>, +} + +struct BindingInner { writers: HashMap<u64, mpsc::Sender<WriterCommand>>, writer_for_conn: HashMap<u64, u64>, conns_for_writer: HashMap<u64, HashSet<u64>>, @@ -60,10 +68,9 @@ struct RegistryInner { writer_idle_since_epoch_secs: HashMap<u64, u64>, } -impl RegistryInner { +impl BindingInner { fn new() -> Self { Self { - map: HashMap::new(), writers: HashMap::new(), writer_for_conn: HashMap::new(), conns_for_writer: HashMap::new(), @@ -75,7 +82,8 @@ impl RegistryInner { } pub struct ConnRegistry { - inner: RwLock<RegistryInner>, + routing: RoutingTable, + binding: BindingState, next_id: AtomicU64, route_channel_capacity: usize, route_backpressure_base_timeout_ms: AtomicU64, @@ -94,7 +102,12 @@ impl ConnRegistry { pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self { let start = rand::random::<u64>() | 1; Self { - inner: RwLock::new(RegistryInner::new()), + routing: RoutingTable { + map: DashMap::new(), + }, + binding: BindingState { + inner: Mutex::new(BindingInner::new()), + }, next_id: AtomicU64::new(start), route_channel_capacity: route_channel_capacity.max(1), route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS),
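// --- Illustrative sketch (not part of the patch): the registry split in a
// nutshell. Hot-path routing lookups hit a sharded DashMap with no async
// lock, while the colder bind/unbind bookkeeping stays behind one tokio
// Mutex. Assumes the dashmap and tokio crates; names loosely mirror the patch.
use dashmap::DashMap;
use tokio::sync::{Mutex, mpsc};

struct Registry {
    routing: DashMap<u64, mpsc::Sender<Vec<u8>>>, // conn_id -> response channel
    binding: Mutex<std::collections::HashMap<u64, u64>>, // conn_id -> writer_id
}

#[tokio::main]
async fn main() {
    let reg = Registry { routing: DashMap::new(), binding: Mutex::new(Default::default()) };
    let (tx, mut rx) = mpsc::channel(8);
    reg.routing.insert(7, tx);
    reg.binding.lock().await.insert(7, 42);

    // Clone the sender out of the shard so the guard is dropped before any
    // await: a stalled binding lock can no longer delay data delivery.
    let tx = reg.routing.get(&7).map(|entry| entry.value().clone());
    if let Some(tx) = tx {
        tx.send(b"frame".to_vec()).await.unwrap();
    }
    assert_eq!(rx.recv().await.unwrap(), b"frame");
}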
@@ -130,14 +143,14 @@ impl ConnRegistry { pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = mpsc::channel(self.route_channel_capacity); - self.inner.write().await.map.insert(id, tx); + self.routing.map.insert(id, tx); (id, rx) } pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender<WriterCommand>) { - let mut inner = self.inner.write().await; - inner.writers.insert(writer_id, tx); - inner + let mut binding = self.binding.inner.lock().await; + binding.writers.insert(writer_id, tx); + binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new); } /// Unregister connection, returning associated writer_id if any. pub async fn unregister(&self, id: u64) -> Option<u64> { - let mut inner = self.inner.write().await; - inner.map.remove(&id); - inner.meta.remove(&id); - if let Some(writer_id) = inner.writer_for_conn.remove(&id) { - let became_empty = if let Some(set) = inner.conns_for_writer.get_mut(&writer_id) { + self.routing.map.remove(&id); + let mut binding = self.binding.inner.lock().await; + binding.meta.remove(&id); + if let Some(writer_id) = binding.writer_for_conn.remove(&id) { + let became_empty = if let Some(set) = binding.conns_for_writer.get_mut(&writer_id) { set.remove(&id); set.is_empty() } else { false }; if became_empty { - inner + binding .writer_idle_since_epoch_secs .insert(writer_id, Self::now_epoch_secs()); } @@ -167,10 +180,7 @@ impl ConnRegistry { #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { - let tx = { - let inner = self.inner.read().await; - inner.map.get(&id).cloned() - }; + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); let Some(tx) = tx else { return RouteResult::NoConn; @@ -223,10 +233,7 @@ impl ConnRegistry { } pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult { - let tx = { - let inner = self.inner.read().await; - inner.map.get(&id).cloned() - }; + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); let Some(tx) = tx else { return RouteResult::NoConn; @@ -249,10 +256,7 @@ return self.route_nowait(id, resp).await; } - let tx = { - let inner = self.inner.read().await; - inner.map.get(&id).cloned() - }; + let tx = self.routing.map.get(&id).map(|entry| entry.value().clone()); let Some(tx) = tx else { return RouteResult::NoConn; @@ -291,33 +295,39 @@ }
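// --- Illustrative sketch (not part of the patch): the "routing is the source
// of truth" contract that bind_writer below enforces. A binding may only
// exist for a conn still present in the routing table, so bind checks routing
// first and refuses otherwise. Simplified types, standard library only:
use std::collections::{HashMap, HashSet};

struct Reg {
    routing: HashSet<u64>,              // live conn_ids
    writer_for_conn: HashMap<u64, u64>, // conn_id -> writer_id
}

impl Reg {
    fn bind_writer(&mut self, conn_id: u64, writer_id: u64) -> bool {
        // Refuse to attach a writer to a conn the router already forgot;
        // otherwise the binding would leak until some later prune pass.
        if !self.routing.contains(&conn_id) {
            return false;
        }
        self.writer_for_conn.insert(conn_id, writer_id);
        true
    }
}

fn main() {
    let mut reg = Reg { routing: HashSet::from([1]), writer_for_conn: HashMap::new() };
    assert!(reg.bind_writer(1, 10));
    assert!(!reg.bind_writer(2, 10)); // conn 2 was never registered / already gone
    assert_eq!(reg.writer_for_conn.len(), 1);
}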
pub async fn bind_writer(&self, conn_id: u64, writer_id: u64, meta: ConnMeta) -> bool { - let mut inner = self.inner.write().await; - if !inner.writers.contains_key(&writer_id) { + let mut binding = self.binding.inner.lock().await; + // ROUTING IS THE SOURCE OF TRUTH: + // never keep/attach writer binding for a connection that is already + // absent from the routing table. + if !self.routing.map.contains_key(&conn_id) { + return false; + } + if !binding.writers.contains_key(&writer_id) { + return false; + } - let previous_writer_id = inner.writer_for_conn.insert(conn_id, writer_id); + let previous_writer_id = binding.writer_for_conn.insert(conn_id, writer_id); if let Some(previous_writer_id) = previous_writer_id && previous_writer_id != writer_id { let became_empty = - if let Some(set) = inner.conns_for_writer.get_mut(&previous_writer_id) { + if let Some(set) = binding.conns_for_writer.get_mut(&previous_writer_id) { set.remove(&conn_id); set.is_empty() } else { false }; if became_empty { - inner + binding .writer_idle_since_epoch_secs .insert(previous_writer_id, Self::now_epoch_secs()); } } - inner.meta.insert(conn_id, meta.clone()); - inner.last_meta_for_writer.insert(writer_id, meta); - inner.writer_idle_since_epoch_secs.remove(&writer_id); - inner + binding.meta.insert(conn_id, meta.clone()); + binding.last_meta_for_writer.insert(writer_id, meta); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new) @@ -326,32 +336,32 @@ } pub async fn mark_writer_idle(&self, writer_id: u64) { - let mut inner = self.inner.write().await; - inner + let mut binding = self.binding.inner.lock().await; + binding .conns_for_writer .entry(writer_id) .or_insert_with(HashSet::new); - inner + binding .writer_idle_since_epoch_secs .entry(writer_id) .or_insert(Self::now_epoch_secs()); } pub async fn get_last_writer_meta(&self, writer_id: u64) -> Option<ConnMeta> { - let inner = self.inner.read().await; - inner.last_meta_for_writer.get(&writer_id).cloned() + let binding = self.binding.inner.lock().await; + binding.last_meta_for_writer.get(&writer_id).cloned() } pub async fn writer_idle_since_snapshot(&self) -> HashMap<u64, u64> { - let inner = self.inner.read().await; - inner.writer_idle_since_epoch_secs.clone() + let binding = self.binding.inner.lock().await; + binding.writer_idle_since_epoch_secs.clone() } pub async fn writer_idle_since_for_writer_ids(&self, writer_ids: &[u64]) -> HashMap<u64, u64> { - let inner = self.inner.read().await; + let binding = self.binding.inner.lock().await; let mut out = HashMap::<u64, u64>::with_capacity(writer_ids.len()); for writer_id in writer_ids { - if let Some(idle_since) = inner.writer_idle_since_epoch_secs.get(writer_id).copied() { + if let Some(idle_since) = binding.writer_idle_since_epoch_secs.get(writer_id).copied() { out.insert(*writer_id, idle_since); } } @@ -359,14 +369,14 @@ } pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot { - let inner = self.inner.read().await; + let binding = self.binding.inner.lock().await; let mut bound_clients_by_writer = HashMap::<u64, usize>::new(); let mut active_sessions_by_target_dc = HashMap::<i32, usize>::new(); - for (writer_id, conn_ids) in &inner.conns_for_writer { + for (writer_id, conn_ids) in &binding.conns_for_writer { bound_clients_by_writer.insert(*writer_id, conn_ids.len()); } - for conn_meta in inner.meta.values() { + for conn_meta in binding.meta.values() { if conn_meta.target_dc == 0 { continue; } @@ -382,9 +392,39 @@ } pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> { - let inner = self.inner.read().await; - let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?; - let writer = inner.writers.get(&writer_id).cloned()?; + let mut binding = self.binding.inner.lock().await; + // ROUTING IS THE SOURCE OF TRUTH: + // stale bindings are ignored and lazily cleaned when routing no longer 
+ // contains the connection. + if !self.routing.map.contains_key(&conn_id) { + binding.meta.remove(&conn_id); + if let Some(stale_writer_id) = binding.writer_for_conn.remove(&conn_id) + && let Some(conns) = binding.conns_for_writer.get_mut(&stale_writer_id) + { + conns.remove(&conn_id); + if conns.is_empty() { + binding + .writer_idle_since_epoch_secs + .insert(stale_writer_id, Self::now_epoch_secs()); + } + } + return None; + } + + let writer_id = binding.writer_for_conn.get(&conn_id).copied()?; + let Some(writer) = binding.writers.get(&writer_id).cloned() else { + binding.writer_for_conn.remove(&conn_id); + binding.meta.remove(&conn_id); + if let Some(conns) = binding.conns_for_writer.get_mut(&writer_id) { + conns.remove(&conn_id); + if conns.is_empty() { + binding + .writer_idle_since_epoch_secs + .insert(writer_id, Self::now_epoch_secs()); + } + } + return None; + }; Some(ConnWriter { writer_id, tx: writer, @@ -392,16 +432,16 @@ } pub async fn active_conn_ids(&self) -> Vec<u64> { - let inner = self.inner.read().await; - inner.writer_for_conn.keys().copied().collect() + let binding = self.binding.inner.lock().await; + binding.writer_for_conn.keys().copied().collect() } pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> { - let mut inner = self.inner.write().await; - inner.writers.remove(&writer_id); - inner.last_meta_for_writer.remove(&writer_id); - inner.writer_idle_since_epoch_secs.remove(&writer_id); - let conns = inner + let mut binding = self.binding.inner.lock().await; + binding.writers.remove(&writer_id); + binding.last_meta_for_writer.remove(&writer_id); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + let conns = binding .conns_for_writer .remove(&writer_id) .unwrap_or_default() @@ -410,11 +450,11 @@ let mut out = Vec::new(); for conn_id in conns { - if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { + if binding.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { continue; } - inner.writer_for_conn.remove(&conn_id); - if let Some(m) = inner.meta.get(&conn_id) { + binding.writer_for_conn.remove(&conn_id); + if let Some(m) = binding.meta.get(&conn_id) { out.push(BoundConn { conn_id, meta: m.clone(), @@ -426,13 +466,13 @@ } #[allow(dead_code)] pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> { - let inner = self.inner.read().await; - inner.meta.get(&conn_id).cloned() + let binding = self.binding.inner.lock().await; + binding.meta.get(&conn_id).cloned() } pub async fn is_writer_empty(&self, writer_id: u64) -> bool { - let inner = self.inner.read().await; - inner + let binding = self.binding.inner.lock().await; + binding .conns_for_writer .get(&writer_id) .map(|s| s.is_empty())
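// --- Illustrative sketch (not part of the patch): the idle-since bookkeeping
// that get_writer's lazy cleanup and unregister keep consistent. When a
// writer loses its last bound conn it is stamped idle "now"; binding a conn
// clears the stamp. Standard library only, simplified types:
use std::collections::{HashMap, HashSet};

#[derive(Default)]
struct IdleBook {
    conns_for_writer: HashMap<u64, HashSet<u64>>,
    idle_since: HashMap<u64, u64>, // writer_id -> epoch secs
}

impl IdleBook {
    fn unbind(&mut self, writer_id: u64, conn_id: u64, now: u64) {
        if let Some(set) = self.conns_for_writer.get_mut(&writer_id) {
            set.remove(&conn_id);
            if set.is_empty() {
                // Last client gone: start the idle clock for this writer.
                self.idle_since.insert(writer_id, now);
            }
        }
    }
}

fn main() {
    let mut book = IdleBook::default();
    book.conns_for_writer.insert(5, HashSet::from([1, 2]));
    book.unbind(5, 1, 100);
    assert!(book.idle_since.is_empty()); // one conn still bound
    book.unbind(5, 2, 160);
    assert_eq!(book.idle_since.get(&5), Some(&160));
}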
@@ -441,8 +481,8 @@ #[allow(dead_code)] pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { - let mut inner = self.inner.write().await; - let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { + let mut binding = self.binding.inner.lock().await; + let Some(conn_ids) = binding.conns_for_writer.get(&writer_id) else { // Writer is already absent from the registry. return true; }; @@ -450,19 +490,19 @@ return false; } - inner.writers.remove(&writer_id); - inner.last_meta_for_writer.remove(&writer_id); - inner.writer_idle_since_epoch_secs.remove(&writer_id); - inner.conns_for_writer.remove(&writer_id); + binding.writers.remove(&writer_id); + binding.last_meta_for_writer.remove(&writer_id); + binding.writer_idle_since_epoch_secs.remove(&writer_id); + binding.conns_for_writer.remove(&writer_id); true } #[allow(dead_code)] pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> { - let inner = self.inner.read().await; + let binding = self.binding.inner.lock().await; let mut out = HashSet::<u64>::with_capacity(writer_ids.len()); for writer_id in writer_ids { - if let Some(conns) = inner.conns_for_writer.get(writer_id) + if let Some(conns) = binding.conns_for_writer.get(writer_id) && !conns.is_empty() { out.insert(*writer_id); diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 504270a..a167773 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,9 +1,12 @@ use httpdate; +use std::sync::Arc; use std::time::SystemTime; use tracing::{debug, info, warn}; +use super::http_fetch::https_get; use super::selftest::record_timeskew_sample; use crate::error::{ProxyError, Result}; +use crate::transport::UpstreamManager; pub const PROXY_SECRET_MIN_LEN: usize = 32; @@ -33,11 +36,21 @@ pub(super) fn validate_proxy_secret_len(data_len: usize, max_len: usize) -> Result<()> } /// Fetch Telegram proxy-secret binary. +#[allow(dead_code)] pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result<Vec<u8>> { + fetch_proxy_secret_with_upstream(cache_path, max_len, None).await +} + +/// Fetch Telegram proxy-secret binary, optionally through upstream routing. +pub async fn fetch_proxy_secret_with_upstream( + cache_path: Option<&str>, + max_len: usize, + upstream: Option<Arc<UpstreamManager>>, +) -> Result<Vec<u8>> {
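// --- Illustrative sketch (not part of the patch): the fetch contract of
// fetch_proxy_secret_with_upstream — prefer a fresh download (optionally via
// the upstream router), persist it best-effort, and fall back to the on-disk
// cache only when the network path fails. Helper names here are stand-ins,
// not the project's API:
async fn fetch(cache: &str) -> Result<Vec<u8>, String> {
    match download().await {
        Ok(data) => {
            // Cache-write failures are non-fatal: the fresh secret still wins.
            let _ = tokio::fs::write(cache, &data).await;
            Ok(data)
        }
        Err(net_err) => tokio::fs::read(cache)
            .await
            .map_err(|_| format!("download failed ({net_err}) and no usable cache")),
    }
}

async fn download() -> Result<Vec<u8>, String> {
    Err("simulated network failure".into()) // force the cache path in this demo
}

#[tokio::main]
async fn main() {
    tokio::fs::write("proxy-secret.demo", b"cached").await.unwrap();
    assert_eq!(fetch("proxy-secret.demo").await.unwrap(), b"cached".to_vec());
}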
let cache = cache_path.unwrap_or("proxy-secret"); // 1) Try fresh download first. - match download_proxy_secret_with_max_len(max_len).await { + match download_proxy_secret_with_max_len_via_upstream(max_len, upstream).await { Ok(data) => { if let Err(e) = tokio::fs::write(cache, &data).await { warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); @@ -76,20 +89,25 @@ pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result<Vec<u8>> } } +#[allow(dead_code)] pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result<Vec<u8>> { - let resp = reqwest::get("https://core.telegram.org/getProxySecret") .await .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + download_proxy_secret_with_max_len_via_upstream(max_len, None).await +} - if !resp.status().is_success() { +pub async fn download_proxy_secret_with_max_len_via_upstream( + max_len: usize, + upstream: Option<Arc<UpstreamManager>>, +) -> Result<Vec<u8>> { + let resp = https_get("https://core.telegram.org/getProxySecret", upstream).await?; + + if !(200..=299).contains(&resp.status) { return Err(ProxyError::Proxy(format!( "proxy-secret download HTTP {}", - resp.status() + resp.status ))); } - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -110,11 +128,7 @@ pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result<Vec<u8>> let mut no_writer_deadline: Option<Instant> = None; @@ -77,7 +83,11 @@ impl MePool { let mut async_recovery_triggered = false; let mut hybrid_recovery_round = 0u32; let mut hybrid_last_recovery_at: Option<Instant> = None; - let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50)); + let mut hybrid_total_deadline: Option<Instant> = None; + let hybrid_wait_step = self + .route_runtime + .me_route_no_writer_wait + .max(Duration::from_millis(50)); let mut hybrid_wait_current = hybrid_wait_step; loop { @@ -92,9 +102,13 @@ impl MePool { .tx .try_send(WriterCommand::Data(current_payload.clone())) { - Ok(()) => return Ok(()), + Ok(()) => { + self.note_hybrid_route_success(); + return Ok(()); + } Err(TrySendError::Full(cmd)) => { if current.tx.send(cmd).await.is_ok() { + self.note_hybrid_route_success(); return Ok(()); } warn!(writer_id = current.writer_id, "ME writer channel closed"); @@ -118,7 +132,7 @@ match no_writer_mode { MeRouteNoWriterMode::AsyncRecoveryFailfast => { let deadline = *no_writer_deadline.get_or_insert_with(|| { - Instant::now() + self.me_route_no_writer_wait + Instant::now() + self.route_runtime.me_route_no_writer_wait }); if !async_recovery_triggered && !unknown_target_dc { let triggered = @@ -139,7 +153,9 @@ MeRouteNoWriterMode::InlineRecoveryLegacy => { self.stats.increment_me_inline_recovery_total(); if !unknown_target_dc { - for _ in 0..self.me_route_inline_recovery_attempts.max(1) { + for _ in + 0..self.route_runtime.me_route_inline_recovery_attempts.max(1) + { for family in self.family_order() { let map = match family { IpFamily::V4 => self.proxy_map_v4.read().await.clone(), @@ -168,7 +184,7 @@ continue; } let deadline = *no_writer_deadline.get_or_insert_with(|| { - Instant::now() + self.me_route_inline_recovery_wait + Instant::now() + self.route_runtime.me_route_inline_recovery_wait }); if !self.wait_for_writer_until(deadline).await { if !self.writers.read().await.is_empty() { @@ -182,6 +198,15 @@ continue; } MeRouteNoWriterMode::HybridAsyncPersistent => { + let total_deadline = 
*hybrid_total_deadline.get_or_insert_with(|| { + Instant::now() + self.hybrid_total_wait_budget() + }); + if Instant::now() >= total_deadline { + self.on_hybrid_timeout(total_deadline, routed_dc); + return Err(ProxyError::Proxy( + "ME writer not available within hybrid timeout".into(), + )); + } if !unknown_target_dc { self.maybe_trigger_hybrid_recovery( routed_dc, @@ -214,8 +239,9 @@ let pick_mode = self.writer_pick_mode(); match no_writer_mode { MeRouteNoWriterMode::AsyncRecoveryFailfast => { - let deadline = *no_writer_deadline .get_or_insert_with(|| Instant::now() + self.me_route_no_writer_wait); + let deadline = *no_writer_deadline.get_or_insert_with(|| { + Instant::now() + self.route_runtime.me_route_no_writer_wait + }); if !async_recovery_triggered && !unknown_target_dc { let triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; @@ -238,7 +264,7 @@ self.stats.increment_me_inline_recovery_total(); if unknown_target_dc { let deadline = *no_writer_deadline.get_or_insert_with(|| { - Instant::now() + self.me_route_inline_recovery_wait + Instant::now() + self.route_runtime.me_route_inline_recovery_wait }); if self.wait_for_candidate_until(routed_dc, deadline).await { continue; @@ -250,7 +276,9 @@ "No ME writers available for target DC".into(), )); } - if emergency_attempts >= self.me_route_inline_recovery_attempts.max(1) { + if emergency_attempts + >= self.route_runtime.me_route_inline_recovery_attempts.max(1) + { self.stats .increment_me_writer_pick_no_candidate_total(pick_mode); self.stats.increment_me_no_writer_failfast_total(); @@ -292,6 +320,16 @@ } } MeRouteNoWriterMode::HybridAsyncPersistent => { + let total_deadline = *hybrid_total_deadline.get_or_insert_with(|| { + Instant::now() + self.hybrid_total_wait_budget() + }); + if Instant::now() >= total_deadline { + self.on_hybrid_timeout(total_deadline, routed_dc); + return Err(ProxyError::Proxy( + "No ME writers available for target DC within hybrid timeout" .into(), + )); + } if !unknown_target_dc { self.maybe_trigger_hybrid_recovery( routed_dc, @@ -332,7 +370,11 @@ pick_sample_size, ) } else { - if self.me_deterministic_writer_sort.load(Ordering::Relaxed) { + if self + .writer_selection_policy + .me_deterministic_writer_sort + .load(Ordering::Relaxed) + { candidate_indices.sort_by(|lhs, rhs| { let left = &writers_snapshot[*lhs]; let right = &writers_snapshot[*rhs]; @@ -423,6 +465,7 @@ "Selected stale ME writer for fallback bind" ); } + self.note_hybrid_route_success(); return Ok(()); } Err(TrySendError::Full(_)) => { @@ -453,7 +496,19 @@ .increment_me_writer_pick_blocking_fallback_total(); let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); - match w.tx.clone().reserve_owned().await { + let reserve_result = + if let Some(timeout) = self.route_runtime.me_route_blocking_send_timeout { + match tokio::time::timeout(timeout, w.tx.clone().reserve_owned()).await { + Ok(result) => result, + Err(_) => { + self.stats.increment_me_writer_pick_full_total(pick_mode); + continue; + } + } + } else { + w.tx.clone().reserve_owned().await + }; + match reserve_result { Ok(permit) => { if !self.registry.bind_writer(conn_id, w.id, meta).await { debug!( @@ -471,6 +526,7 @@ if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); } + self.note_hybrid_route_success(); return Ok(()); } Err(_) => {
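// --- Illustrative sketch (not part of the patch): the new bounded blocking
// send. Instead of awaiting reserve_owned() indefinitely on a full writer
// queue, the patch wraps it in tokio::time::timeout and moves on to the next
// candidate when the budget elapses. Minimal demo, tokio only:
use std::time::Duration;
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, _rx) = mpsc::channel::<u32>(1);
    tx.send(1).await.unwrap(); // queue is now full and nobody drains it

    match tokio::time::timeout(Duration::from_millis(20), tx.clone().reserve_owned()).await {
        Ok(Ok(permit)) => {
            let _ = permit.send(2); // runs only if capacity appeared in time
        }
        Ok(Err(_closed)) => unreachable!("receiver is still alive"),
        Err(_elapsed) => {
            // Timed out: count the full queue and try another writer
            // instead of stalling this route attempt.
        }
    }
}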
@@ -483,7 +539,7 @@ impl MePool { } async fn wait_for_writer_until(&self, deadline: Instant) -> bool { - let waiter = self.writer_available.notified(); + let mut rx = self.writer_epoch.subscribe(); if !self.writers.read().await.is_empty() { return true; } @@ -492,13 +548,14 @@ return !self.writers.read().await.is_empty(); } let timeout = deadline.saturating_duration_since(now); - if tokio::time::timeout(timeout, waiter).await.is_ok() { - return true; + if tokio::time::timeout(timeout, rx.changed()).await.is_ok() { + return !self.writers.read().await.is_empty(); } !self.writers.read().await.is_empty() } async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool { + let mut rx = self.writer_epoch.subscribe(); loop { if self.has_candidate_for_target_dc(routed_dc).await { return true; @@ -509,7 +566,6 @@ return self.has_candidate_for_target_dc(routed_dc).await; } - let waiter = self.writer_available.notified(); if self.has_candidate_for_target_dc(routed_dc).await { return true; } @@ -517,7 +573,7 @@ if remaining.is_zero() { return self.has_candidate_for_target_dc(routed_dc).await; } - if tokio::time::timeout(remaining, waiter).await.is_err() { + if tokio::time::timeout(remaining, rx.changed()).await.is_err() { return self.has_candidate_for_target_dc(routed_dc).await; } } @@ -587,6 +643,9 @@ hybrid_last_recovery_at: &mut Option<Instant>, hybrid_wait_step: Duration, ) { + if !self.try_consume_hybrid_recovery_trigger_slot(HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS) { + return; + } if let Some(last) = *hybrid_last_recovery_at && last.elapsed() < hybrid_wait_step { @@ -602,6 +661,78 @@ *hybrid_last_recovery_at = Some(Instant::now()); } + fn hybrid_total_wait_budget(&self) -> Duration { + let base = self + .route_runtime + .me_route_hybrid_max_wait + .max(Duration::from_millis(50)); + let now_ms = Self::now_epoch_millis(); + let last_success_ms = self + .route_runtime + .me_route_last_success_epoch_ms + .load(Ordering::Relaxed); + if last_success_ms != 0 + && now_ms.saturating_sub(last_success_ms) <= HYBRID_RECENT_SUCCESS_WINDOW_MS + { + return base.saturating_mul(2); + } + base + } + + fn note_hybrid_route_success(&self) { + self.route_runtime + .me_route_last_success_epoch_ms + .store(Self::now_epoch_millis(), Ordering::Relaxed); + } + + fn on_hybrid_timeout(&self, deadline: Instant, routed_dc: i32) { + self.stats.increment_me_hybrid_timeout_total(); + let now_ms = Self::now_epoch_millis(); + let mut last_warn_ms = self + .route_runtime + .me_route_hybrid_timeout_warn_epoch_ms + .load(Ordering::Relaxed); + while now_ms.saturating_sub(last_warn_ms) >= HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS { + match self + .route_runtime + .me_route_hybrid_timeout_warn_epoch_ms + .compare_exchange_weak(last_warn_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) + { + Ok(_) => { + warn!( + routed_dc, + budget_ms = self.hybrid_total_wait_budget().as_millis() as u64, + elapsed_ms = deadline.elapsed().as_millis() as u64, + "ME hybrid route timeout reached" + ); + break; + } + Err(actual) => last_warn_ms = actual, + } + } + } +
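// --- Illustrative sketch (not part of the patch): the lock-free "one trigger
// per interval" slot used by both the warn limiter above and the recovery
// trigger below. compare_exchange_weak may fail spuriously, hence the
// reload-and-retry loop. Standard library only:
use std::sync::atomic::{AtomicU64, Ordering};

fn try_consume(slot: &AtomicU64, now_ms: u64, min_interval_ms: u64) -> bool {
    let mut last = slot.load(Ordering::Relaxed);
    loop {
        if now_ms.saturating_sub(last) < min_interval_ms {
            return false; // someone else consumed this interval already
        }
        match slot.compare_exchange_weak(last, now_ms, Ordering::AcqRel, Ordering::Relaxed) {
            Ok(_) => return true,
            Err(actual) => last = actual, // lost the race; re-evaluate
        }
    }
}

fn main() {
    let slot = AtomicU64::new(0);
    assert!(try_consume(&slot, 10_000, 5_000));
    assert!(!try_consume(&slot, 12_000, 5_000)); // still inside the interval
    assert!(try_consume(&slot, 15_000, 5_000));
}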
+ fn try_consume_hybrid_recovery_trigger_slot(&self, min_interval_ms: u64) -> bool { + let now_ms = Self::now_epoch_millis(); + let mut last_trigger_ms = self + .route_runtime + .me_async_recovery_last_trigger_epoch_ms + .load(Ordering::Relaxed); + loop { + if now_ms.saturating_sub(last_trigger_ms) < min_interval_ms { + return false; + } + match self + .route_runtime + .me_async_recovery_last_trigger_epoch_ms + .compare_exchange_weak(last_trigger_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed) + { + Ok(_) => return true, + Err(actual) => last_trigger_ms = actual, + } + } + } + pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> { if let Some(w) = self.registry.get_writer(conn_id).await { let mut p = Vec::with_capacity(12); @@ -749,7 +880,7 @@ impl MePool { (self.writer_idle_rank_for_selection(writer, idle_since_by_writer, now_epoch_secs) as u64) * 100; - let queue_cap = self.writer_cmd_channel_capacity.max(1) as u64; + let queue_cap = self.writer_lifecycle.writer_cmd_channel_capacity.max(1) as u64; let queue_remaining = writer.tx.capacity() as u64; let queue_used = queue_cap.saturating_sub(queue_remaining.min(queue_cap)); let queue_util_pct = queue_used.saturating_mul(100) / queue_cap; diff --git a/src/transport/middle_proxy/tests/health_adversarial_tests.rs b/src/transport/middle_proxy/tests/health_adversarial_tests.rs index 3444120..4bee91c 100644 --- a/src/transport/middle_proxy/tests/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/health_adversarial_tests.rs @@ -113,6 +113,8 @@ async fn make_pool( general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ); diff --git a/src/transport/middle_proxy/tests/health_integration_tests.rs b/src/transport/middle_proxy/tests/health_integration_tests.rs index b0d3a2a..0a6e110 100644 --- a/src/transport/middle_proxy/tests/health_integration_tests.rs +++ b/src/transport/middle_proxy/tests/health_integration_tests.rs @@ -111,6 +111,8 @@ async fn make_pool( general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ); diff --git a/src/transport/middle_proxy/tests/health_regression_tests.rs b/src/transport/middle_proxy/tests/health_regression_tests.rs index 55bf8f6..92398b4 100644 --- a/src/transport/middle_proxy/tests/health_regression_tests.rs +++ b/src/transport/middle_proxy/tests/health_regression_tests.rs @@ -106,6 +106,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ) diff --git a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs index 2d1e23a..90c8382 100644 --- a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs @@ -95,6 +95,8 @@ async fn make_pool() -> Arc<MePool> { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, )
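// --- Illustrative sketch (not part of the patch): the revised quarantine
// contract that the pool_writer_security tests below pin down. Draining
// removals and ultra-short unexpected lifetimes are counted but never
// quarantine the endpoint; only an unexpected removal of a writer that lived
// long enough does. The 500 ms threshold here is invented for the demo — the
// tests only fix 50 ms as "too short" and 1 s as "long enough":
use std::time::Duration;

enum Removal { Draining, Unexpected { uptime: Duration } }

fn should_quarantine(removal: &Removal) -> bool {
    match removal {
        Removal::Draining => false, // planned teardown is not endpoint flap
        Removal::Unexpected { uptime } => *uptime >= Duration::from_millis(500),
    }
}

fn main() {
    assert!(!should_quarantine(&Removal::Draining));
    assert!(!should_quarantine(&Removal::Unexpected { uptime: Duration::from_millis(50) }));
    assert!(should_quarantine(&Removal::Unexpected { uptime: Duration::from_secs(1) }));
}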
diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs index 7bfc061..0184e11 100644 --- a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -35,7 +35,7 @@ async fn make_pool() -> Arc<MePool> { NetworkDecision::default(), None, Arc::new(SecureRandom::new()), - Arc::new(Stats::default()), + Arc::new(Stats::new()), general.me_keepalive_enabled, general.me_keepalive_interval_secs, general.me_keepalive_jitter_secs, @@ -100,6 +100,8 @@ async fn make_pool() -> Arc<MePool> { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ) @@ -171,10 +173,15 @@ async fn bind_conn_to_writer(pool: &Arc<MePool>, writer_id: u64, port: u16) -> u64 } #[tokio::test] -async fn remove_draining_writer_still_quarantines_flapping_endpoint() { +async fn remove_draining_writer_does_not_quarantine_flapping_endpoint() { let pool = make_pool().await; let writer_id = 77; let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 77)), 443); + let before_total = pool.stats.get_me_endpoint_quarantine_total(); + let before_unexpected = pool.stats.get_me_endpoint_quarantine_unexpected_total(); + let before_suppressed = pool + .stats + .get_me_endpoint_quarantine_draining_suppressed_total(); insert_writer( &pool, writer_id, @@ -198,8 +205,18 @@ async fn remove_draining_writer_still_quarantines_flapping_endpoint() { "writer must be removed from pool after cleanup" ); assert!( - pool.is_endpoint_quarantined(addr).await, - "draining removals must still quarantine flapping endpoints" + !pool.is_endpoint_quarantined(addr).await, + "draining removals must not quarantine endpoint" ); + assert_eq!(pool.stats.get_me_endpoint_quarantine_total(), before_total); + assert_eq!( pool.stats.get_me_endpoint_quarantine_unexpected_total(), before_unexpected ); + assert_eq!( pool.stats .get_me_endpoint_quarantine_draining_suppressed_total(), before_suppressed + 1 ); assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); } @@ -255,16 +272,21 @@ async fn edge_draining_only_detach_rejects_active_writer() { } #[tokio::test] -async fn adversarial_blackhat_single_remove_establishes_single_quarantine_entry() { +async fn adversarial_blackhat_single_unexpected_remove_establishes_single_quarantine_entry() { let pool = make_pool().await; let writer_id = 93; let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 93)), 443); + let before_total = pool.stats.get_me_endpoint_quarantine_total(); + let before_unexpected = pool.stats.get_me_endpoint_quarantine_unexpected_total(); + let before_suppressed = pool + .stats + .get_me_endpoint_quarantine_draining_suppressed_total(); insert_writer( &pool, writer_id, 2, addr, - true, + false, Instant::now() - Duration::from_secs(1), ) .await; @@ -272,6 +294,49 @@ pool.remove_writer_and_close_clients(writer_id).await; assert!(pool.is_endpoint_quarantined(addr).await); assert_eq!(pool.endpoint_quarantine.lock().await.len(), 1); + assert_eq!( pool.stats.get_me_endpoint_quarantine_total(), before_total + 1 ); + assert_eq!( pool.stats.get_me_endpoint_quarantine_unexpected_total(), before_unexpected + 1 ); + assert_eq!( pool.stats .get_me_endpoint_quarantine_draining_suppressed_total(), before_suppressed ); +} + +#[tokio::test] +async fn remove_ultra_short_uptime_writer_skips_flap_quarantine() { + let pool = make_pool().await; + let writer_id = 931; + let addr = 
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 131)), 443); + let before_total = pool.stats.get_me_endpoint_quarantine_total(); + let before_unexpected = pool.stats.get_me_endpoint_quarantine_unexpected_total(); + insert_writer( + &pool, + writer_id, + 2, + addr, + false, + Instant::now() - Duration::from_millis(50), + ) + .await; + + pool.remove_writer_and_close_clients(writer_id).await; + + assert!( + !pool.is_endpoint_quarantined(addr).await, + "ultra-short unexpected lifetime must not quarantine endpoint" + ); + assert_eq!(pool.stats.get_me_endpoint_quarantine_total(), before_total); + assert_eq!( + pool.stats.get_me_endpoint_quarantine_unexpected_total(), + before_unexpected + 1 + ); } #[tokio::test] diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 80379a5..de52d18 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -106,6 +106,8 @@ async fn make_pool() -> (Arc<MePool>, Arc<Stats>) { general.me_warn_rate_limit_ms, general.me_route_no_writer_mode, general.me_route_no_writer_wait_ms, + general.me_route_hybrid_max_wait_ms, + general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ); diff --git a/src/transport/pool.rs b/src/transport/pool.rs index 60f8a01..bb0baac 100644 --- a/src/transport/pool.rs +++ b/src/transport/pool.rs @@ -201,7 +201,10 @@ impl ConnectionPool { pub async fn close_all(&self) { let pools_snapshot: Vec<(SocketAddr, Arc>)> = { let pools = self.pools.read(); - pools.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect() + pools + .iter() + .map(|(addr, pool)| (*addr, Arc::clone(pool))) + .collect() }; for (addr, pool) in pools_snapshot {
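// --- Illustrative sketch (not part of the patch): the close_all pattern the
// final hunk reformats — clone the Arc handles under a short synchronous
// lock, then do the slow per-pool work without holding the map lock across
// any .await point. Simplified stand-in types, tokio assumed:
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

#[tokio::main]
async fn main() {
    let pools: Mutex<HashMap<SocketAddr, Arc<Vec<u8>>>> = Mutex::new(HashMap::new());
    pools.lock().unwrap().insert("127.0.0.1:443".parse().unwrap(), Arc::new(vec![1]));

    // Snapshot: the guard is dropped when this block ends, before any await.
    let snapshot: Vec<(SocketAddr, Arc<Vec<u8>>)> = {
        let guard = pools.lock().unwrap();
        guard.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect()
    };

    for (addr, pool) in snapshot {
        // Awaiting here is safe: no std::sync::MutexGuard is alive.
        tokio::task::yield_now().await;
        println!("closing {addr} ({} conns)", pool.len());
    }
}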