diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1b6e455 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,39 @@ +name: Build + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Build + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry & build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Build Release + run: cargo build --release --verbose \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index def299d..d01293e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,9 @@ jobs: name: GNU ${{ matrix.target }} runs-on: ubuntu-latest + container: + image: rust:slim-bookworm + strategy: fail-fast: false matrix: @@ -47,8 +50,8 @@ jobs: - name: Install deps run: | - sudo apt-get update - sudo apt-get install -y \ + apt-get update + apt-get install -y \ build-essential \ clang \ lld \ @@ -69,14 +72,10 @@ jobs: if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then export CC=aarch64-linux-gnu-gcc export CXX=aarch64-linux-gnu-g++ - export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc - export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc" else export CC=clang export CXX=clang++ - export CC_x86_64_unknown_linux_gnu=clang - export CXX_x86_64_unknown_linux_gnu=clang++ export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld" fi @@ -85,20 +84,19 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ 
matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 + path: dist/* # ========================== # MUSL @@ -125,43 +123,7 @@ jobs: - name: Install deps run: | apt-get update - apt-get install -y \ - musl-tools \ - pkg-config \ - curl - - - uses: actions/cache@v4 - if: matrix.target == 'aarch64-unknown-linux-musl' - with: - path: ~/.musl-aarch64 - key: musl-toolchain-aarch64-v1 - - - name: Install aarch64 musl toolchain - if: matrix.target == 'aarch64-unknown-linux-musl' - run: | - set -e - - TOOLCHAIN_DIR="$HOME/.musl-aarch64" - ARCHIVE="aarch64-linux-musl-cross.tgz" - URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE" - - if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then - echo "✅ MUSL toolchain already installed" - else - echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..." 
- - curl -fL \ - --retry 5 \ - --retry-delay 3 \ - --connect-timeout 10 \ - --max-time 120 \ - -o "$ARCHIVE" "$URL" - - mkdir -p "$TOOLCHAIN_DIR" - tar -xzf "$ARCHIVE" --strip-components=1 -C "$TOOLCHAIN_DIR" - fi - - echo "$TOOLCHAIN_DIR/bin" >> $GITHUB_PATH + apt-get install -y musl-tools pkg-config curl - name: Add rust target run: rustup target add ${{ matrix.target }} @@ -178,11 +140,9 @@ jobs: run: | if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then export CC=aarch64-linux-musl-gcc - export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc export RUSTFLAGS="-C target-feature=+crt-static -C linker=aarch64-linux-musl-gcc" else export CC=musl-gcc - export CC_x86_64_unknown_linux_musl=musl-gcc export RUSTFLAGS="-C target-feature=+crt-static" fi @@ -191,69 +151,19 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 - -# ========================== -# Docker -# ========================== - docker: - name: Docker - runs-on: ubuntu-latest - needs: [build-gnu, build-musl] - continue-on-error: true - - steps: - - uses: actions/checkout@v4 - - - uses: actions/download-artifact@v4 - with: - path: artifacts - - - name: Extract binaries - run: | - mkdir dist - find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \; - - cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true - - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - - - name: 
Login to GHCR - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract version - id: vars - run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT - - - name: Build & Push - uses: docker/build-push-action@v6 - with: - context: . - push: true - platforms: linux/amd64,linux/arm64 - tags: | - ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }} - ghcr.io/${{ github.repository }}:latest - build-args: | - BINARY=dist/telemt + path: dist/* # ========================== # Release @@ -271,7 +181,7 @@ jobs: with: path: artifacts - - name: Flatten artifacts + - name: Flatten run: | mkdir dist find artifacts -type f -exec cp {} dist/ \; @@ -281,5 +191,61 @@ jobs: with: files: dist/* generate_release_notes: true - draft: false - prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }} + prerelease: ${{ contains(github.ref, '-') }} + +# ========================== +# Docker (FROM RELEASE) +# ========================== + docker: + name: Docker (from release) + runs-on: ubuntu-latest + needs: release + + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + + - name: Install gh + run: apt-get update && apt-get install -y gh + + - name: Extract version + id: vars + run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + + - name: Download binary + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir dist + + gh release download ${{ steps.vars.outputs.VERSION }} \ + --repo ${{ github.repository }} \ + --pattern "telemt-x86_64-linux-musl.tar.gz" \ + --dir dist + + tar -xzf dist/telemt-x86_64-linux-musl.tar.gz -C dist + chmod +x dist/telemt + + - uses: docker/setup-qemu-action@v3 + - uses: docker/setup-buildx-action@v3 + + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - 
name: Build & Push + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: linux/amd64,linux/arm64 + tags: | + ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }} + ghcr.io/${{ github.repository }}:latest + build-args: | + BINARY=dist/telemt \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index 799f2ce..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Rust - -on: - push: - branches: [ "*" ] - pull_request: - branches: [ "*" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - name: Build - runs-on: ubuntu-latest - - permissions: - contents: read - actions: write - checks: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - name: Cache cargo registry & build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo- - - - name: Build Release - run: cargo build --release --verbose - - - name: Run tests - run: cargo test --verbose - - - name: Stress quota-lock suites (PR only) - if: github.event_name == 'pull_request' - env: - RUST_TEST_THREADS: 16 - run: | - set -euo pipefail - for i in $(seq 1 12); do - echo "[quota-lock-stress] iteration ${i}/12" - cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 - cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 - done - -# clippy dont fail on warnings because of active development of telemt -# and many warnings - - name: Run clippy - run: cargo clippy -- --cap-lints warn - - - name: Check for unused dependencies - run: cargo udeps || true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 
0000000..46d3b0a --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,127 @@ +name: Check + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true + +jobs: +# ========================== +# Formatting +# ========================== + fmt: + name: Fmt + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - run: cargo fmt -- --check + +# ========================== +# Tests +# ========================== + test: + name: Test + runs-on: ubuntu-latest + + permissions: + contents: read + actions: write + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - run: cargo test --verbose + +# ========================== +# Clippy +# ========================== + clippy: + name: Clippy + runs-on: ubuntu-latest + + permissions: + contents: read + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - run: cargo clippy -- --cap-lints warn + +# ========================== +# Udeps +# ========================== + udeps: + name: Udeps + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + 
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Install cargo-udeps + run: cargo install cargo-udeps || true + + # тоже не валит билд + - run: cargo udeps || true \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8159a22..92da630 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1454,9 +1454,9 @@ dependencies = [ [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", @@ -1486,7 +1486,7 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -1495,9 +1495,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -1659,9 +1681,9 @@ dependencies = [ [[package]] name = "moka" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b" +checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046" dependencies = [ "crossbeam-channel", "crossbeam-epoch", @@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.29" +version = "3.3.30" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 53082db..1e06b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] name = "telemt" -version = "3.3.29" +version = "3.3.30" edition = "2024" +[features] +redteam_offline_expected_fail = [] + [dependencies] # C libc = "0.2" diff --git a/Dockerfile b/Dockerfile index 372f702..eac46f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,9 @@ # syntax=docker/dockerfile:1 -# ========================== -# Stage 1: Build -# ========================== -FROM rust:1.88-slim-bookworm AS builder - -RUN apt-get update && apt-get install -y --no-install-recommends \ - pkg-config \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -# Depcache -COPY Cargo.toml Cargo.lock* ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ - cargo build --release 2>/dev/null || true && \ - rm -rf src - -# Build -COPY . . 
-RUN cargo build --release && strip target/release/telemt +ARG BINARY # ========================== -# Stage 2: Compress (strip + UPX) +# Stage: minimal # ========================== FROM debian:12-slim AS minimal @@ -33,7 +13,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ && rm -rf /var/lib/apt/lists/* \ \ - # install UPX from Telemt releases && curl -fL \ --retry 5 \ --retry-delay 3 \ @@ -46,15 +25,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && chmod +x /usr/local/bin/upx \ && rm -rf /tmp/upx* -COPY --from=builder /build/target/release/telemt /telemt +COPY ${BINARY} /telemt RUN strip /telemt || true RUN upx --best --lzma /telemt || true # ========================== -# Stage 3: Debug base +# Debug image # ========================== -FROM debian:12-slim AS debug-base +FROM debian:12-slim AS debug RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ @@ -64,48 +43,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ busybox \ && rm -rf /var/lib/apt/lists/* -# ========================== -# Stage 4: Debug image -# ========================== -FROM debug-base AS debug - WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -USER root - -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] CMD ["config.toml"] # ========================== -# Stage 5: Production (distroless) +# Production (REAL distroless) # ========================== -FROM gcr.io/distroless/base-debian12 AS prod +FROM gcr.io/distroless/static-debian12 AS prod WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -# TLS + timezone + shell -COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs -COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo -COPY --from=debug-base /bin/busybox /bin/busybox - -RUN ["/bin/busybox", "--install", "-s", "/bin"] - -# distroless user USER 
nonroot:nonroot -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] -CMD ["config.toml"] +CMD ["config.toml"] \ No newline at end of file diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 33e5b29..e9d42a9 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -202,12 +202,15 @@ This document lists all configuration keys accepted by `config.toml`. | listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. | | proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. | | proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). | +| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). | | metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). | | metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. | | metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. | | max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). | | accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). | +Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. 
For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted. + ## [server.api] | Parameter | Type | Default | Constraints / validation | Description | @@ -271,6 +274,8 @@ This document lists all configuration keys accepted by `config.toml`. | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. | +| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. | | mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | | mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. | | mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. 
| diff --git a/docs/VPS_DOUBLE_HOP.en.md b/docs/VPS_DOUBLE_HOP.en.md index 9463b79..6b6abe5 100644 --- a/docs/VPS_DOUBLE_HOP.en.md +++ b/docs/VPS_DOUBLE_HOP.en.md @@ -63,7 +63,7 @@ recommended range from 5 to 2147483647 inclusive > [!IMPORTANT] > It is recommended to use your own, unique values.\ -> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters. +> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters. #### Server B Configuration (Netherlands): @@ -84,6 +84,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 diff --git a/docs/VPS_DOUBLE_HOP.ru.md b/docs/VPS_DOUBLE_HOP.ru.md index 625c64c..037dfcb 100644 --- a/docs/VPS_DOUBLE_HOP.ru.md +++ b/docs/VPS_DOUBLE_HOP.ru.md @@ -44,7 +44,7 @@ awg genkey | tee private.key | awg pubkey > public.key Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\ Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\ -Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**). +Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**). 
Рекомендации по выбору значений: ```text @@ -62,7 +62,7 @@ H1/H2/H3/H4 — должны быть уникальны и отличаться ``` > [!IMPORTANT] > Рекомендуется использовать собственные, уникальные значения.\ -> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html). +> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html). #### Конфигурация Сервера B (_Нидерланды_): @@ -83,6 +83,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -272,7 +276,7 @@ backend telemt_nodes ``` >[!WARNING] ->**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!** +>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!** #### Разрешаем порт 443\tcp в фаерволе (если включен) ```bash diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 66ffeda..b0aaf5b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String { String::new() } +pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 { + 5_000 +} + +pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 { + 15_000 +} + +pub(crate) fn default_tls_fetch_strict_route() -> bool { + true +} + +pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 { + 600 +} + pub(crate) fn default_mask_port() -> u16 { 443 } @@ -185,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 } +pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec { + vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] 
+} + pub(crate) fn default_server_max_connections() -> u32 { 10_000 } @@ -553,6 +573,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize { 512 } +#[cfg(not(test))] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 5 * 1024 * 1024 +} + +#[cfg(test)] +pub(crate) fn default_mask_relay_max_bytes() -> usize { + 32 * 1024 +} + +pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 { + 5 +} + pub(crate) fn default_mask_timing_normalization_enabled() -> bool { false } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index e580b7f..7f7499e 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -228,7 +228,9 @@ impl HotFields { me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate, me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes, - me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes, + me_d2c_frame_buf_shrink_threshold_bytes: cfg + .general + .me_d2c_frame_buf_shrink_threshold_bytes, direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, @@ -600,6 +602,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur || old.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes + || old.censorship.mask_classifier_prefetch_timeout_ms + != new.censorship.mask_classifier_prefetch_timeout_ms || old.censorship.mask_timing_normalization_enabled != new.censorship.mask_timing_normalization_enabled || 
old.censorship.mask_timing_normalization_floor_ms diff --git a/src/config/load.rs b/src/config/load.rs index bf6d036..3cb6627 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; @@ -430,6 +430,24 @@ impl ProxyConfig { )); } + if config.censorship.mask_relay_max_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_relay_max_bytes > 67_108_864 { + return Err(ProxyError::Config( + "censorship.mask_relay_max_bytes must be <= 67108864".to_string(), + )); + } + + if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { + return Err(ProxyError::Config( + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(), + )); + } + if config.censorship.mask_timing_normalization_ceiling_ms < config.censorship.mask_timing_normalization_floor_ms { @@ -539,7 +557,9 @@ impl ProxyConfig { )); } - if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) { + if !(4096..=16 * 1024 * 1024) + .contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) + { return Err(ProxyError::Config( "general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]" .to_string(), @@ -957,6 +977,28 @@ impl ProxyConfig { // Normalize optional TLS fetch scope: whitespace-only values disable scoped routing. 
config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string(); + if config.censorship.tls_fetch.profiles.is_empty() { + config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles; + } else { + let mut seen = HashSet::new(); + config + .censorship + .tls_fetch + .profiles + .retain(|profile| seen.insert(*profile)); + } + + if config.censorship.tls_fetch.attempt_timeout_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(), + )); + } + if config.censorship.tls_fetch.total_budget_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.total_budget_ms must be > 0".to_string(), + )); + } + // Merge primary + extra TLS domains, deduplicate (primary always first). if !config.censorship.tls_domains.is_empty() { let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); @@ -1134,6 +1176,10 @@ mod load_security_tests; #[path = "tests/load_mask_shape_security_tests.rs"] mod load_mask_shape_security_tests; +#[cfg(test)] +#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"] +mod load_mask_classifier_prefetch_timeout_security_tests; + #[cfg(test)] mod tests { use super::*; @@ -1239,6 +1285,11 @@ mod tests { assert_eq!(cfg.general.update_every, default_update_every()); assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4()); assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt()); + assert_eq!( + cfg.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1371,6 +1422,14 @@ mod tests { let server = ServerConfig::default(); assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6())); + assert_eq!( + server.proxy_protocol_trusted_cidrs, + 
default_proxy_protocol_trusted_cidrs() + ); + assert_eq!( + AntiCensorshipConfig::default().unknown_sni_action, + UnknownSniAction::Drop + ); assert_eq!(server.api.listen, default_api_listen()); assert_eq!(server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1406,6 +1465,75 @@ mod tests { assert_eq!(access.users, default_access_users()); } + #[test] + fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() { + let cfg_missing: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert_eq!( + cfg_missing.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + + let cfg_explicit_empty: ProxyConfig = toml::from_str( + r#" + [server] + proxy_protocol_trusted_cidrs = [] + + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert!( + cfg_explicit_empty + .server + .proxy_protocol_trusted_cidrs + .is_empty() + ); + } + + #[test] + fn unknown_sni_action_parses_and_defaults_to_drop() { + let cfg_default: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + "#, + ) + .unwrap(); + assert_eq!( + cfg_default.censorship.unknown_sni_action, + UnknownSniAction::Drop + ); + + let cfg_mask: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + unknown_sni_action = "mask" + "#, + ) + .unwrap(); + assert_eq!( + cfg_mask.censorship.unknown_sni_action, + UnknownSniAction::Mask + ); + } + #[test] fn dc_overrides_allow_string_and_array() { let toml = r#" @@ -2353,6 +2481,94 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tls_fetch_defaults_are_applied() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = 
ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + TlsFetchConfig::default().profiles + ); + assert!(cfg.censorship.tls_fetch.strict_route); + assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000); + assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000); + assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_profiles_are_deduplicated_preserving_order() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"] + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + vec![ + TlsFetchProfile::CompatTls12, + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::LegacyMinimal + ] + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_attempt_timeout_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + attempt_timeout_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_total_budget_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + total_budget_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = 
dir.join("telemt_tls_fetch_total_budget_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_ad_tag_is_disabled_during_load() { let toml = r#" diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs new file mode 100644 index 0000000..0b3d543 --- /dev/null +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -0,0 +1,76 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!( + "telemt-load-mask-prefetch-timeout-security-{nonce}.toml" + )); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 4 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch timeout below minimum security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 51 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("prefetch 
timeout above max security bound must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"), + "error must explain timeout bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() { + let path = write_temp_config( + r#" +[censorship] +mask_classifier_prefetch_timeout_ms = 20 +"#, + ); + + let cfg = + ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted"); + assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 8986a49..bccd36f 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8 remove_temp_config(&path); } + +#[test] +fn load_rejects_zero_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be > 0"), + "error must explain non-zero relay cap invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_mask_relay_max_bytes_above_upper_bound() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 67108865 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), + "error must explain relay cap upper bound invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn 
load_accepts_valid_mask_relay_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_relay_max_bytes = 8388608 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted"); + assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608); + + remove_temp_config(&path); +} diff --git a/src/config/types.rs b/src/config/types.rs index aa58dc1..3939664 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -954,7 +954,8 @@ impl Default for GeneralConfig { me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(), - me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(), + me_d2c_frame_buf_shrink_threshold_bytes: + default_me_d2c_frame_buf_shrink_threshold_bytes(), direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), @@ -1239,9 +1240,10 @@ pub struct ServerConfig { /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. /// - /// When non-empty, connections from addresses outside this allowlist are - /// rejected before `src_addr` is applied. - #[serde(default)] + /// If this field is omitted in config, it defaults to trust-all CIDRs + /// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list, + /// all PROXY protocol headers are rejected. + #[serde(default = "default_proxy_protocol_trusted_cidrs")] pub proxy_protocol_trusted_cidrs: Vec, /// Port for the Prometheus-compatible metrics endpoint. 
@@ -1286,7 +1288,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), - proxy_protocol_trusted_cidrs: Vec::new(), + proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), @@ -1357,6 +1359,90 @@ impl Default for TimeoutsConfig { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum UnknownSniAction { + #[default] + Drop, + Mask, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsFetchProfile { + ModernChromeLike, + ModernFirefoxLike, + CompatTls12, + LegacyMinimal, +} + +impl TlsFetchProfile { + pub fn as_str(self) -> &'static str { + match self { + TlsFetchProfile::ModernChromeLike => "modern_chrome_like", + TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like", + TlsFetchProfile::CompatTls12 => "compat_tls12", + TlsFetchProfile::LegacyMinimal => "legacy_minimal", + } + } +} + +fn default_tls_fetch_profiles() -> Vec { + vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsFetchConfig { + /// Ordered list of ClientHello profiles used for adaptive fallback. + #[serde(default = "default_tls_fetch_profiles")] + pub profiles: Vec, + + /// When true and upstream route is configured, TLS fetch fails closed on + /// upstream connect errors and does not fallback to direct TCP. + #[serde(default = "default_tls_fetch_strict_route")] + pub strict_route: bool, + + /// Timeout per one profile attempt in milliseconds. 
+ #[serde(default = "default_tls_fetch_attempt_timeout_ms")] + pub attempt_timeout_ms: u64, + + /// Total wall-clock budget in milliseconds across all profile attempts. + #[serde(default = "default_tls_fetch_total_budget_ms")] + pub total_budget_ms: u64, + + /// Adds GREASE-style values into selected ClientHello extensions. + #[serde(default)] + pub grease_enabled: bool, + + /// Produces deterministic ClientHello randomness for debugging/tests. + #[serde(default)] + pub deterministic: bool, + + /// TTL for winner-profile cache entries in seconds. + /// Set to 0 to disable profile cache. + #[serde(default = "default_tls_fetch_profile_cache_ttl_secs")] + pub profile_cache_ttl_secs: u64, +} + +impl Default for TlsFetchConfig { + fn default() -> Self { + Self { + profiles: default_tls_fetch_profiles(), + strict_route: default_tls_fetch_strict_route(), + attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(), + total_budget_ms: default_tls_fetch_total_budget_ms(), + grease_enabled: false, + deterministic: false, + profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1366,11 +1452,19 @@ pub struct AntiCensorshipConfig { #[serde(default)] pub tls_domains: Vec, + /// Policy for TLS ClientHello with unknown (non-configured) SNI. + #[serde(default)] + pub unknown_sni_action: UnknownSniAction, + /// Upstream scope used for TLS front metadata fetches. /// Empty value keeps default upstream routing behavior. #[serde(default = "default_tls_fetch_scope")] pub tls_fetch_scope: String, + /// Fetch strategy for TLS front metadata bootstrap and periodic refresh. 
+ #[serde(default)] + pub tls_fetch: TlsFetchConfig, + #[serde(default = "default_true")] pub mask: bool, @@ -1450,6 +1544,14 @@ pub struct AntiCensorshipConfig { #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] pub mask_shape_above_cap_blur_max_bytes: usize, + /// Maximum bytes relayed per direction on unauthenticated masking fallback paths. + #[serde(default = "default_mask_relay_max_bytes")] + pub mask_relay_max_bytes: usize, + + /// Prefetch timeout (ms) for extending fragmented masking classifier window. + #[serde(default = "default_mask_classifier_prefetch_timeout_ms")] + pub mask_classifier_prefetch_timeout_ms: u64, + /// Enable outcome-time normalization envelope for masking fallback. #[serde(default = "default_mask_timing_normalization_enabled")] pub mask_timing_normalization_enabled: bool, @@ -1468,7 +1570,9 @@ impl Default for AntiCensorshipConfig { Self { tls_domain: default_tls_domain(), tls_domains: Vec::new(), + unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), + tls_fetch: TlsFetchConfig::default(), mask: default_true(), mask_host: None, mask_port: default_mask_port(), @@ -1488,6 +1592,8 @@ impl Default for AntiCensorshipConfig { mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_relay_max_bytes: default_mask_relay_max_bytes(), + mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(), mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), diff --git a/src/error.rs b/src/error.rs index d9aeb22..49c8c81 100644 --- a/src/error.rs +++ b/src/error.rs @@ -216,6 +216,9 @@ pub enum ProxyError { #[error("Invalid proxy 
protocol header")] InvalidProxyProtocol, + #[error("Unknown TLS SNI")] + UnknownTlsSni, + #[error("Proxy error: {0}")] Proxy(String), diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 35f796f..032460c 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn}; use crate::cli; use crate::config::ProxyConfig; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ - ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, + ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, + save_proxy_config_cache, }; pub(crate) fn resolve_runtime_config_path( @@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( cache_path: Option<&str>, me2dc_fallback: bool, label: &'static str, + upstream: Option>, ) -> Option { loop { - match fetch_proxy_config_with_raw(url).await { + match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await { Ok((cfg, raw)) => { if !cfg.map.is_empty() { if let Some(path) = cache_path diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 022f8ae..b1e605c 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool( let proxy_secret_path = config.general.proxy_secret_path.as_deref(); let pool_size = config.general.middle_proxy_pool_size.max(1); let proxy_secret = loop { - match crate::transport::middle_proxy::fetch_proxy_secret( + match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream( proxy_secret_path, config.general.proxy_secret_len_max, + Some(upstream_manager.clone()), ) .await { @@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v4_cache_path.as_deref(), me2dc_fallback, "getProxyConfig", + Some(upstream_manager.clone()), ) .await; if cfg_v4.is_some() { @@ -160,6 +162,7 @@ pub(crate) async fn 
initialize_me_pool( config.general.proxy_config_v6_cache_path.as_deref(), me2dc_fallback, "getProxyConfigV6", + Some(upstream_manager.clone()), ) .await; if cfg_v6.is_some() { diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 342a2f9..7cf3039 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::config::ProxyConfig; use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker}; use crate::tls_front::TlsFrontCache; +use crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; pub(crate) async fn bootstrap_tls_front( @@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front( let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty()) .then(|| config.censorship.tls_fetch_scope.clone()); - let fetch_timeout = Duration::from_secs(5); + let tls_fetch = config.censorship.tls_fetch.clone(); + let fetch_strategy = TlsFetchStrategy { + profiles: tls_fetch.profiles, + strict_route: tls_fetch.strict_route, + attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)), + total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)), + grease_enabled: tls_fetch.grease_enabled, + deterministic: tls_fetch.deterministic, + profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs), + }; + let fetch_timeout = fetch_strategy.total_budget; let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); @@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); + let strategy_initial = fetch_strategy.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { @@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = 
unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); + let strategy_domain = strategy_initial.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, @@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); + let strategy_refresh = fetch_strategy.clone(); tokio::spawn(async move { loop { let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); @@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); + let strategy_domain = strategy_refresh.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, diff --git a/src/main.rs b/src/main.rs index c512e6b..e5d931f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,12 +7,12 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] -#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] -mod ip_tracker_hotpath_adversarial_tests; -#[cfg(test)] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] mod ip_tracker_encapsulation_adversarial_tests; #[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; mod maestro; @@ -29,5 +29,6 @@ mod util; #[tokio::main] async fn main() -> 
std::result::Result<(), Box> { + let _ = rustls::crypto::ring::default_provider().install_default(); maestro::run().await } diff --git a/src/metrics.rs b/src/metrics.rs index a821d4d..f9475f6 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets" ); - let _ = writeln!( - out, - "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter" - ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"); let _ = writeln!( out, "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}", diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 4b7f57e..8ce3e96 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -186,6 +186,72 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { } } +const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16; +#[cfg(test)] +const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5); + +fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration { + Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms) +} + +fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { + if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW { + return false; + } + + if initial_data.is_empty() { + // Empty initial_data means there is no client probe prefix to refine. + // Prefetching in this case can consume fallback relay payload bytes and + // accidentally route them through shaping heuristics. 
+ return false; + } + + if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") { + return false; + } + + initial_data + .iter() + .all(|b| b.is_ascii_alphabetic() || *b == b' ') +} + +#[cfg(test)] +async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec) +where + R: AsyncRead + Unpin, +{ + extend_masking_initial_window_with_timeout( + reader, + initial_data, + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + ) + .await; +} + +async fn extend_masking_initial_window_with_timeout( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + return; + } + + let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + fn masking_outcome( reader: R, writer: W, @@ -200,6 +266,15 @@ where W: AsyncWrite + Unpin + Send + 'static, { HandshakeOutcome::NeedsMasking(Box::pin(async move { + let mut reader = reader; + let mut initial_data = initial_data; + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + mask_classifier_prefetch_timeout(&config), + ) + .await; + handle_bad_client( reader, writer, @@ -242,13 +317,20 @@ fn record_handshake_failure_class( record_beobachten_class(beobachten, config, peer_ip, class); } +#[inline] +fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) { + if matches!(error, ProxyError::UnknownTlsSni) { + stats.increment_connects_bad(); + } +} + fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { if trusted.is_empty() { static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); if !warned.swap(true, 
Ordering::Relaxed) { warn!( - "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers" ); } return false; @@ -433,7 +515,10 @@ where beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -884,7 +969,10 @@ impl RunningClientHandler { self.beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -1153,7 +1241,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1212,7 +1300,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1321,6 +1409,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; #[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] mod masking_probe_evasion_blackhat_tests; +#[cfg(test)] +#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"] +mod masking_fragmented_classifier_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_replay_timing_security_tests.rs"] +mod masking_replay_timing_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"] +mod 
masking_http2_fragmented_preface_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"] +mod masking_prefetch_invariant_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"] +mod masking_prefetch_timing_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"] +mod masking_prefetch_config_runtime_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"] +mod masking_prefetch_config_pipeline_integration_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"] +mod masking_prefetch_strict_boundary_security_tests; + #[cfg(test)] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] mod beobachten_ttl_bounds_security_tests; @@ -1328,3 +1448,15 @@ mod beobachten_ttl_bounds_security_tests; #[cfg(test)] #[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] mod tls_record_wrap_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_clever_advanced_tests.rs"] +mod client_clever_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_more_advanced_tests.rs"] +mod client_more_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_deep_invariants_tests.rs"] +mod client_deep_invariants_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5632977..2ef8e1b 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, trace, warn}; use zeroize::{Zeroize, Zeroizing}; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, UnknownSniAction}; use crate::crypto::{AesCtr, SecureRandom, sha256}; use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; @@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: 
Instant) -> usize { hasher.finish() as usize } +fn auth_probe_scan_start_offset( + peer_ip: IpAddr, + now: Instant, + state_len: usize, + scan_limit: usize, +) -> usize { + if state_len == 0 || scan_limit == 0 { + return 0; + } + + auth_probe_eviction_offset(peer_ip, now) % state_len +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); @@ -269,34 +282,9 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = if state_len == 0 { - 0 - } else { - auth_probe_eviction_offset(peer_ip, now) % state_len - }; - let mut scanned = 0usize; - for entry in state.iter().skip(start_offset) { - let key = *entry.key(); - let fail_streak = entry.value().fail_streak; - let last_seen = entry.value().last_seen; - match eviction_candidate { - Some((_, current_fail, current_seen)) - if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => {} - _ => eviction_candidate = Some((key, fail_streak, last_seen)), - } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); - } - scanned += 1; - if scanned >= scan_limit { - break; - } - } - - if scanned < scan_limit { - for entry in state.iter().take(scan_limit - scanned) { + if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT { + for entry in state.iter() { let key = *entry.key(); let fail_streak = entry.value().fail_streak; let last_seen = entry.value().last_seen; @@ -310,6 +298,46 @@ fn auth_probe_record_failure_with_state( stale_keys.push(key); } } + } else { + let start_offset = + auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = 
entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail + && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } } for stale_key in stale_keys { @@ -501,6 +529,21 @@ fn decode_user_secrets( secrets } +#[inline] +fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { + if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { + return Some(config.censorship.tls_domain.as_str()); + } + + for domain in &config.censorship.tls_domains { + if domain.eq_ignore_ascii_case(sni) { + return Some(domain.as_str()); + } + } + + None +} + async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { if config.censorship.server_hello_delay_max_ms == 0 { return; @@ -584,70 +627,12 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); - let secrets = decode_user_secrets(config, client_sni.as_deref()); - - let validation = match tls::validate_tls_handshake_with_replay_window( - handshake, - &secrets, - config.access.ignore_time_skew, - config.access.replay_window_secs, - ) { - Some(v) => v, - None => { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - debug!( 
- peer = %peer, - ignore_time_skew = config.access.ignore_time_skew, - "TLS handshake validation failed - no matching user or time skew" - ); - return HandshakeResult::BadClient { reader, writer }; - } - }; - - // Replay tracking is applied only after successful authentication to avoid - // letting unauthenticated probes evict valid entries from the replay cache. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { - Some((_, s)) => s, - None => { - maybe_apply_server_hello_delay(config).await; - return HandshakeResult::BadClient { reader, writer }; - } - }; - - let cached = if config.censorship.tls_emulation { - if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = client_sni.as_ref() { - if cache.contains_domain(sni).await { - sni.clone() - } else { - config.censorship.tls_domain.clone() - } - } else { - config.censorship.tls_domain.clone() - }; - let cached_entry = cache.get(&selected_domain).await; - let use_full_cert_payload = cache - .take_full_cert_budget_for_ip( - peer.ip(), - Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), - ) - .await; - Some((cached_entry, use_full_cert_payload)) - } else { - None - } - } else { - None - }; + let preferred_user_hint = client_sni + .as_deref() + .filter(|sni| config.access.users.contains_key(*sni)); + let matched_tls_domain = client_sni + .as_deref() + .and_then(|sni| find_matching_tls_domain(config, sni)); let alpn_list = if config.censorship.alpn_enforce { tls::extract_alpn_from_client_hello(handshake) @@ -670,6 +655,81 @@ where None }; + if client_sni.is_some() && 
matched_tls_domain.is_none() && preferred_user_hint.is_none() { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + sni = ?client_sni, + action = ?config.censorship.unknown_sni_action, + "TLS handshake rejected by unknown SNI policy" + ); + return match config.censorship.unknown_sni_action { + UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni), + UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer }, + }; + } + + let secrets = decode_user_secrets(config, preferred_user_hint); + + let validation = match tls::validate_tls_handshake_with_replay_window( + handshake, + &secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ) { + Some(v) => v, + None => { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + ignore_time_skew = config.access.ignore_time_skew, + "TLS handshake validation failed - no matching user or time skew" + ); + return HandshakeResult::BadClient { reader, writer }; + } + }; + + // Reject known replay digests before expensive cache/domain/ALPN policy work. 
+ let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { + Some((_, s)) => s, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } + }; + + let cached = if config.censorship.tls_emulation { + if let Some(cache) = tls_cache.as_ref() { + let selected_domain = + matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); + let cached_entry = cache.get(selected_domain).await; + let use_full_cert_payload = cache + .take_full_cert_budget_for_ip( + peer.ip(), + Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), + ) + .await; + Some((cached_entry, use_full_cert_payload)) + } else { + None + } + } else { + None + }; + + // Add replay digest only for policy-valid handshakes. 
+ replay_checker.add_tls_digest(digest_half); + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, @@ -769,7 +829,7 @@ where let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); - let dec_key = sha256(&dec_key_input); + let dec_key = Zeroizing::new(sha256(&dec_key_input)); let mut dec_iv_arr = [0u8; IV_LEN]; dec_iv_arr.copy_from_slice(dec_iv_bytes); @@ -805,7 +865,7 @@ where let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); - let enc_key = sha256(&enc_key_input); + let enc_key = Zeroizing::new(sha256(&enc_key_input)); let mut enc_iv_arr = [0u8; IV_LEN]; enc_iv_arr.copy_from_slice(enc_iv_bytes); @@ -830,9 +890,9 @@ where user: user.clone(), dc_idx, proto_tag, - dec_key, + dec_key: *dec_key, dec_iv, - enc_key, + enc_key: *enc_key, enc_iv, peer, is_tls, @@ -979,6 +1039,38 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"] +mod auth_probe_scan_budget_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"] +mod auth_probe_scan_offset_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"] +mod auth_probe_eviction_bias_security_tests; + +#[cfg(test)] +#[path = "tests/handshake_advanced_clever_tests.rs"] +mod advanced_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_more_clever_tests.rs"] +mod more_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_real_bug_stress_tests.rs"] +mod real_bug_stress_tests; + +#[cfg(test)] +#[path = 
"tests/handshake_timing_manual_bench_tests.rs"] +mod timing_manual_bench_tests; + +#[cfg(test)] +#[path = "tests/handshake_key_material_zeroization_security_tests.rs"] +mod handshake_key_material_zeroization_security_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 3639db1..ba9f20a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -4,14 +4,23 @@ use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; -use rand::{Rng, RngExt}; -use std::net::SocketAddr; +#[cfg(unix)] +use nix::ifaddrs::getifaddrs; +use rand::rngs::StdRng; +use rand::{Rng, RngExt, SeedableRng}; +use std::net::{IpAddr, SocketAddr}; use std::str; -use std::time::Duration; +#[cfg(test)] +use std::sync::atomic::{AtomicUsize, Ordering}; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +#[cfg(unix)] +use tokio::sync::Mutex as AsyncMutex; use tokio::time::{Instant, timeout}; use tracing::debug; @@ -30,28 +39,55 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +#[cfg(unix)] +#[cfg(not(test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); +#[cfg(all(unix, test))] +const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1); struct CopyOutcome { total: usize, ended_by_eof: bool, } -async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome 
+async fn copy_with_idle_timeout( + reader: &mut R, + writer: &mut W, + byte_cap: usize, + shutdown_on_eof: bool, +) -> CopyOutcome where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; + + if byte_cap == 0 { + return CopyOutcome { + total, + ended_by_eof, + }; + } + loop { - let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, }; if n == 0 { ended_by_eof = true; + if shutdown_on_eof { + let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await; + } break; } total = total.saturating_add(n); @@ -68,6 +104,31 @@ where } } +fn is_http_probe(data: &[u8]) -> bool { + // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". + const HTTP_METHODS: [&[u8]; 10] = [ + b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH", + b"PRI ", + ]; + + if data.is_empty() { + return false; + } + + let window = &data[..data.len().min(16)]; + for method in HTTP_METHODS { + if data.len() >= method.len() && window.starts_with(method) { + return true; + } + + if (2..=3).contains(&window.len()) && method.starts_with(window) { + return true; + } + } + + false +} + fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { if total == 0 || floor == 0 || cap < floor { return total; @@ -125,6 +186,11 @@ async fn maybe_write_shape_padding( let mut remaining = target_total - total_sent; let mut pad_chunk = [0u8; 1024]; let deadline = Instant::now() + MASK_TIMEOUT; + // Use a Send RNG so relay futures remain spawn-safe under Tokio. 
+ let mut rng = { + let mut seed_source = rand::rng(); + StdRng::from_rng(&mut seed_source) + }; while remaining > 0 { let now = Instant::now(); @@ -133,10 +199,7 @@ async fn maybe_write_shape_padding( } let write_len = remaining.min(pad_chunk.len()); - { - let mut rng = rand::rng(); - rng.fill_bytes(&mut pad_chunk[..write_len]); - } + rng.fill_bytes(&mut pad_chunk[..write_len]); let write_budget = deadline.saturating_duration_since(now); match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { Ok(Ok(())) => {} @@ -167,11 +230,11 @@ where } } -async fn consume_client_data_with_timeout(reader: R) +async fn consume_client_data_with_timeout_and_cap(reader: R, byte_cap: usize) where R: AsyncRead + Unpin, { - if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap)) .await .is_err() { @@ -190,6 +253,13 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { if config.censorship.mask_timing_normalization_enabled { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; + if floor == 0 { + if ceiling == 0 { + return Duration::from_millis(0); + } + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(0..=ceiling)); + } if ceiling > floor { let mut rng = rand::rng(); return Duration::from_millis(rng.random_range(floor..=ceiling)); @@ -219,14 +289,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request - if data.len() > 4 - && (data.starts_with(b"GET ") - || data.starts_with(b"POST") - || data.starts_with(b"HEAD") - || data.starts_with(b"PUT ") - || data.starts_with(b"DELETE") - || data.starts_with(b"OPTIONS")) - { + if is_http_probe(data) { return "HTTP"; } @@ -248,6 +311,247 @@ fn detect_client_type(data: &[u8]) -> 
&'static str { "unknown" } +fn parse_mask_host_ip_literal(host: &str) -> Option { + if host.starts_with('[') && host.ends_with(']') { + return host[1..host.len() - 1].parse::().ok(); + } + host.parse::().ok() +} + +fn canonical_ip(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .map(IpAddr::V4) + .unwrap_or(IpAddr::V6(v6)), + IpAddr::V4(v4) => IpAddr::V4(v4), + } +} + +#[cfg(unix)] +fn collect_local_interface_ips() -> Vec { + #[cfg(test)] + LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed); + + let mut out = Vec::new(); + if let Ok(addrs) = getifaddrs() { + for iface in addrs { + if let Some(address) = iface.address { + if let Some(v4) = address.as_sockaddr_in() { + out.push(canonical_ip(IpAddr::V4(v4.ip()))); + } else if let Some(v6) = address.as_sockaddr_in6() { + out.push(canonical_ip(IpAddr::V6(v6.ip()))); + } + } + } + } + out +} + +fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec) -> Vec { + if refreshed.is_empty() && !previous.is_empty() { + return previous.to_vec(); + } + + refreshed +} + +#[cfg(unix)] +#[derive(Default)] +struct LocalInterfaceCache { + ips: Vec, + refreshed_at: Option, +} + +#[cfg(unix)] +static LOCAL_INTERFACE_CACHE: OnceLock> = OnceLock::new(); + +#[cfg(unix)] +static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock> = OnceLock::new(); + +#[cfg(all(unix, test))] +fn local_interface_ips() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default())); + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + let refreshed = collect_local_interface_ips(); + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(unix)] +async fn local_interface_ips_async() -> Vec { + let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| 
Mutex::new(LocalInterfaceCache::default())); + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let _refresh_guard = refresh_lock.lock().await; + + { + let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if !stale { + return guard.ips.clone(); + } + } + + let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips) + .await + .unwrap_or_default(); + + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + let stale = guard + .refreshed_at + .is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL); + if stale { + guard.ips = choose_interface_snapshot(&guard.ips, refreshed); + guard.refreshed_at = Some(StdInstant::now()); + } + + guard.ips.clone() +} + +#[cfg(all(not(unix), test))] +fn local_interface_ips() -> Vec { + Vec::new() +} + +#[cfg(not(unix))] +async fn local_interface_ips_async() -> Vec { + Vec::new() +} + +#[cfg(test)] +static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0); + +#[cfg(test)] +fn reset_local_interface_enumerations_for_tests() { + LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed); + + #[cfg(unix)] + if let Some(cache) = LOCAL_INTERFACE_CACHE.get() { + let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner()); + guard.ips.clear(); + guard.refreshed_at = None; + } +} + +#[cfg(test)] +fn local_interface_enumerations_for_tests() -> usize { + LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed) +} + +fn is_mask_target_local_listener_with_interfaces( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, + interface_ips: &[IpAddr], +) -> bool { + if 
mask_port != local_addr.port() { + return false; + } + + let local_ip = canonical_ip(local_addr.ip()); + let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip); + + if let Some(addr) = resolved_override { + let resolved_ip = canonical_ip(addr.ip()); + if resolved_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (resolved_ip.is_loopback() + || resolved_ip.is_unspecified() + || interface_ips.contains(&resolved_ip)) + { + return true; + } + } + + if let Some(mask_ip) = literal_mask_ip { + if mask_ip == local_ip { + return true; + } + + if local_ip.is_unspecified() + && (mask_ip.is_loopback() + || mask_ip.is_unspecified() + || interface_ips.contains(&mask_ip)) + { + return true; + } + } + + false +} + +#[cfg(test)] +fn is_mask_target_local_listener( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips(); + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +async fn is_mask_target_local_listener_async( + mask_host: &str, + mask_port: u16, + local_addr: SocketAddr, + resolved_override: Option, +) -> bool { + if mask_port != local_addr.port() { + return false; + } + + let interfaces = local_interface_ips_async().await; + is_mask_target_local_listener_with_interfaces( + mask_host, + mask_port, + local_addr, + resolved_override, + &interfaces, + ) +} + +fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration { + let minutes = config.general.beobachten_minutes; + let clamped = minutes.clamp(1, 24 * 60); + Duration::from_secs(clamped.saturating_mul(60)) +} + fn build_mask_proxy_header( version: u8, peer: SocketAddr, @@ -290,13 +594,14 @@ pub async fn handle_bad_client( { let client_type = detect_client_type(initial_data); if config.general.beobachten { - let ttl = 
Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)); + let ttl = masking_beobachten_ttl(config); beobachten.record(client_type, peer.ip(), ttl); } if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; return; } @@ -341,6 +646,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -353,12 +659,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -372,6 +686,29 @@ pub async fn handle_bad_client( .unwrap_or(&config.censorship.tls_domain); let mask_port = config.censorship.mask_port; + // Fail closed when fallback points at our own listener endpoint. + // Self-referential masking can create recursive proxy loops under + // misconfiguration and leak distinguishable load spikes to adversaries. 
+ let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port); + if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr) + .await + { + let outcome_started = Instant::now(); + debug!( + client_type = client_type, + host = %mask_host, + port = mask_port, + local = %local_addr, + "Mask target resolves to local listener; refusing self-referential masking fallback" + ); + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; + wait_mask_outcome_budget(outcome_started, config).await; + return; + } + + let outcome_started = Instant::now(); + debug!( client_type = client_type, host = %mask_host, @@ -381,10 +718,9 @@ pub async fn handle_bad_client( ); // Apply runtime DNS override for mask target when configured. - let mask_addr = resolve_socket_addr(mask_host, mask_port) + let mask_addr = resolved_mask_addr .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); - let outcome_started = Instant::now(); let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { @@ -413,6 +749,7 @@ pub async fn handle_bad_client( config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_hardening_aggressive_mode, + config.censorship.mask_relay_max_bytes, ), ) .await @@ -425,12 +762,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout(reader).await; + consume_client_data_with_timeout_and_cap( + 
reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -449,6 +794,7 @@ async fn relay_to_mask( shape_above_cap_blur: bool, shape_above_cap_blur_max_bytes: usize, shape_hardening_aggressive_mode: bool, + mask_relay_max_bytes: usize, ) where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, @@ -464,8 +810,18 @@ async fn relay_to_mask( } let (upstream_copy, downstream_copy) = tokio::join!( - async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, - async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } + async { + copy_with_idle_timeout( + &mut reader, + &mut mask_write, + mask_relay_max_bytes, + !shape_hardening_enabled, + ) + .await + }, + async { + copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await + } ); let total_sent = initial_data.len().saturating_add(upstream_copy.total); @@ -491,13 +847,36 @@ async fn relay_to_mask( let _ = writer.shutdown().await; } -/// Just consume all data from client without responding -async fn consume_client_data(mut reader: R) { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - while let Ok(n) = reader.read(&mut buf).await { +/// Just consume all data from client without responding. +async fn consume_client_data(mut reader: R, byte_cap: usize) { + if byte_cap == 0 { + return; + } + + // Keep drain path fail-closed under slow-loris stalls. 
+ let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut total = 0usize; + + loop { + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + break; + } + + let read_len = remaining_budget.min(MASK_BUFFER_SIZE); + let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { break; } + + total = total.saturating_add(n); + if total >= byte_cap { + break; + } } } @@ -521,6 +900,10 @@ mod masking_shape_above_cap_blur_security_tests; #[path = "tests/masking_timing_normalization_security_tests.rs"] mod masking_timing_normalization_security_tests; +#[cfg(test)] +#[path = "tests/masking_timing_budget_coupling_security_tests.rs"] +mod masking_timing_budget_coupling_security_tests; + #[cfg(test)] #[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] mod masking_ab_envelope_blur_integration_security_tests; @@ -548,3 +931,75 @@ mod masking_aggressive_mode_security_tests; #[cfg(test)] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] mod masking_timing_sidechannel_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/masking_self_target_loop_security_tests.rs"] +mod masking_self_target_loop_security_tests; + +#[cfg(test)] +#[path = "tests/masking_classification_completeness_security_tests.rs"] +mod masking_classification_completeness_security_tests; + +#[cfg(test)] +#[path = "tests/masking_relay_guardrails_security_tests.rs"] +mod masking_relay_guardrails_security_tests; + +#[cfg(test)] +#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"] +mod masking_connect_failure_close_matrix_security_tests; + +#[cfg(test)] +#[path = "tests/masking_additional_hardening_security_tests.rs"] +mod masking_additional_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_idle_timeout_security_tests.rs"] +mod masking_consume_idle_timeout_security_tests; + +#[cfg(test)] 
+#[path = "tests/masking_http2_probe_classification_security_tests.rs"] +mod masking_http2_probe_classification_security_tests; + +#[cfg(test)] +#[path = "tests/masking_http_probe_boundary_security_tests.rs"] +mod masking_http_probe_boundary_security_tests; + +#[cfg(test)] +#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"] +mod masking_rng_hoist_perf_regression_tests; + +#[cfg(test)] +#[path = "tests/masking_http2_preface_integration_security_tests.rs"] +mod masking_http2_preface_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_consume_stress_adversarial_tests.rs"] +mod masking_consume_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_security_tests.rs"] +mod masking_interface_cache_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"] +mod masking_interface_cache_defense_in_depth_security_tests; + +#[cfg(test)] +#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"] +mod masking_interface_cache_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/masking_production_cap_regression_security_tests.rs"] +mod masking_production_cap_regression_security_tests; + +#[cfg(test)] +#[path = "tests/masking_extended_attack_surface_security_tests.rs"] +mod masking_extended_attack_surface_security_tests; + +#[cfg(test)] +#[path = "tests/masking_padding_timeout_adversarial_tests.rs"] +mod masking_padding_timeout_adversarial_tests; + +#[cfg(all(test, feature = "redteam_offline_expected_fail"))] +#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"] +mod masking_offline_target_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index d0f5ffb..3259597 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,5 +1,7 @@ use std::collections::hash_map::RandomState; use std::collections::{BTreeSet, HashMap}; +#[cfg(test)] +use std::future::Future; use 
std::hash::{BuildHasher, Hash}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; @@ -8,7 +10,7 @@ use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -21,7 +23,9 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats}; +use crate::stats::{ + MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, +}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -39,28 +43,23 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +const TINY_FRAME_DEBT_PER_TINY: u32 = 8; +const TINY_FRAME_DEBT_LIMIT: u32 = 512; #[cfg(test)] const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); #[cfg(not(test))] const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; 
+const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); @@ -94,10 +93,24 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> +{ + let registry = relay_idle_candidate_registry(); + match registry.lock() { + Ok(guard) => guard, + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + // Fail closed after panic while holding registry lock: drop all + // candidates and pressure cursors to avoid stale cross-session state. 
+ *guard = RelayIdleCandidateRegistry::default(); + registry.clear_poison(); + guard + } + } +} + fn mark_relay_idle_candidate(conn_id: u64) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); if guard.by_conn_id.contains_key(&conn_id) { return false; @@ -116,9 +129,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool { } fn clear_relay_idle_candidate(conn_id: u64) { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); if let Some(meta) = guard.by_conn_id.remove(&conn_id) { guard.ordered.remove(&(meta.mark_order_seq, conn_id)); @@ -127,23 +138,17 @@ fn clear_relay_idle_candidate(conn_id: u64) { #[cfg(test)] fn oldest_relay_idle_candidate() -> Option { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return None; - }; + let guard = relay_idle_candidate_registry_lock(); guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) } fn note_relay_pressure_event() { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return; - }; + let mut guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); } fn relay_pressure_event_seq() -> u64 { - let Ok(guard) = relay_idle_candidate_registry().lock() else { - return 0; - }; + let guard = relay_idle_candidate_registry_lock(); guard.pressure_event_seq } @@ -152,9 +157,7 @@ fn maybe_evict_idle_candidate_on_pressure( seen_pressure_seq: &mut u64, stats: &Stats, ) -> bool { - let Ok(mut guard) = relay_idle_candidate_registry().lock() else { - return false; - }; + let mut guard = relay_idle_candidate_registry_lock(); let latest_pressure_seq = guard.pressure_event_seq; if latest_pressure_seq == *seen_pressure_seq { @@ -199,13 +202,9 @@ fn maybe_evict_idle_candidate_on_pressure( #[cfg(test)] fn clear_relay_idle_pressure_state_for_testing() { - if let 
Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() - && let Ok(mut guard) = registry.lock() - { - guard.by_conn_id.clear(); - guard.ordered.clear(); - guard.pressure_event_seq = 0; - guard.pressure_consumed_seq = 0; + if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() { + let mut guard = relay_idle_candidate_registry_lock(); + *guard = RelayIdleCandidateRegistry::default(); } RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); } @@ -259,6 +258,7 @@ impl RelayClientIdlePolicy { struct RelayClientIdleState { last_client_frame_at: Instant, soft_idle_marked: bool, + tiny_frame_debt: u32, } impl RelayClientIdleState { @@ -266,6 +266,7 @@ impl RelayClientIdleState { Self { last_client_frame_at: now, soft_idle_marked: false, + tiny_frame_debt: 0, } } @@ -531,48 +532,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } -fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) -} - -#[cfg_attr(not(test), allow(dead_code))] -fn quota_would_be_exceeded_for_user( - stats: &Stats, - user: &str, - quota_limit: Option, - bytes: u64, -) -> bool { - quota_limit.is_some_and(|quota| { - let used = stats.get_user_total_octets(user); - used >= quota || bytes > quota.saturating_sub(used) - }) -} - fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, - overshoot: u64, -) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot)) -} - -fn quota_would_be_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, +async fn reserve_user_quota_with_yield( + user_stats: &UserStats, bytes: u64, - overshoot: u64, -) -> bool { - quota_limit.is_some_and(|quota| { - let cap = quota_soft_cap(quota, 
overshoot); - let used = stats.get_user_total_octets(user); - used >= cap || bytes > cap.saturating_sub(used) - }) + limit: u64, +) -> std::result::Result { + loop { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match user_stats.quota_try_reserve(bytes, limit) { + Ok(total) => return Ok(total), + Err(QuotaReserveError::LimitExceeded) => { + return Err(QuotaReserveError::LimitExceeded); + } + Err(QuotaReserveError::Contended) => std::hint::spin_loop(), + } + } + + tokio::task::yield_now().await; + } } fn classify_me_d2c_flush_reason( @@ -619,53 +600,18 @@ fn observe_me_d2c_flush_event( } #[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { +fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); TEST_LOCK.get_or_init(|| Mutex::new(())) } #[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() +pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> { + relay_idle_pressure_test_guard() .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()) } -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(AsyncMutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(AsyncMutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - 
dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -691,6 +637,16 @@ async fn enqueue_c2me_command( } } +#[cfg(test)] +async fn run_relay_test_step_timeout(context: &'static str, fut: F) -> T +where + F: Future, +{ + timeout(RELAY_TEST_STEP_TIMEOUT, fut) + .await + .unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs())) +} + pub(crate) async fn handle_via_middle_proxy( mut crypto_reader: CryptoReader, crypto_writer: CryptoWriter, @@ -711,6 +667,7 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); + let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -837,6 +794,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let quota_user_stats_me_writer = quota_user_stats.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -866,6 +824,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), @@ -924,6 +883,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), @@ -985,6 +945,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), @@ -1048,6 +1009,7 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + 
quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, bytes_me2c_clone.as_ref(), @@ -1219,16 +1181,23 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - if let Some(limit) = quota_limit { - let quota_lock = quota_user_lock(&user); - let _quota_guard = quota_lock.lock().await; - stats.add_user_octets_from(&user, payload.len() as u64); - if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + if reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + ) + .await + .is_err() + { main_result = Err(ProxyError::DataQuotaExceeded { user: user.clone(), }); break; } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); } else { stats.add_user_octets_from(&user, payload.len() as u64); } @@ -1321,6 +1290,8 @@ async fn read_client_payload_with_idle_policy( where R: AsyncRead + Unpin + Send + 'static, { + const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4; + async fn read_exact_with_policy( client_reader: &mut CryptoReader, buf: &mut [u8], @@ -1459,6 +1430,7 @@ where Ok(()) } + let mut consecutive_zero_len_frames = 0u32; loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { @@ -1539,6 +1511,26 @@ where }; if len == 0 { + idle_state.tiny_frame_debt = idle_state + .tiny_frame_debt + .saturating_add(TINY_FRAME_DEBT_PER_TINY); + if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT { + stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy(format!( + "Tiny frame overhead limit exceeded: debt={}, conn_id={}", + idle_state.tiny_frame_debt, forensics.conn_id + ))); + } + + if !idle_policy.enabled { + consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1); + if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { + 
stats.increment_relay_protocol_desync_close_total(); + return Err(ProxyError::Proxy( + "Excessive zero-length abridged frames".to_string(), + )); + } + } continue; } if len < 4 && proto_tag != ProtoTag::Abridged { @@ -1607,6 +1599,7 @@ where } *frame_counter += 1; idle_state.on_client_frame(Instant::now()); + idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1); clear_relay_idle_candidate(forensics.conn_id); return Ok(Some((payload, quickack))); } @@ -1690,6 +1683,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_user_stats: Option<&UserStats>, quota_limit: Option, quota_soft_overshoot_bytes: u64, bytes_me2c: &AtomicU64, @@ -1708,40 +1702,42 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if quota_would_be_exceeded_for_user_soft( - stats, - user, - quota_limit, - data_len, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); + if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { + let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); + if reserve_user_quota_with_yield(user_stats, data_len, soft_limit) + .await + .is_err() + { + stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } } let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - stats.increment_me_d2c_write_mode(write_mode); + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + if quota_limit.is_some() { + stats.add_quota_write_fail_bytes_total(data_len); + stats.increment_quota_write_fail_events_total(); + } + return Err(err); + } + }; - bytes_me2c.fetch_add(data.len() as u64, 
Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() as u64); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data.len() as u64); - - if quota_exceeded_for_user_soft( - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - ) { - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + if let Some(user_stats) = quota_user_stats { + stats.add_user_octets_to_handle(user_stats, data_len); + } else { + stats.add_user_octets_to(user, data_len); } + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); Ok(MeWriterResponseOutcome::Continue { frames: 1, @@ -1841,8 +1837,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else if len_words < (1 << 24) { @@ -1864,8 +1866,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first, lw[0], lw[1], lw[2]]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else { @@ -1907,8 +1915,14 @@ where MeD2cWriteMode::Coalesced } else { let header = len_val.to_le_bytes(); - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; 
+ client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; if padding_len > 0 { frame_buf.clear(); if frame_buf.capacity() < padding_len { @@ -1948,10 +1962,6 @@ where .map_err(ProxyError::Io) } -#[cfg(test)] -#[path = "tests/middle_relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; @@ -1964,18 +1974,30 @@ mod desync_all_full_dedup_security_tests; #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] -mod coverage_high_risk_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] -mod quota_overflow_lock_security_tests; - #[cfg(test)] #[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] mod length_cast_hardening_security_tests; #[cfg(test)] -#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] -mod blackhat_campaign_integration_tests; +#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] +mod middle_relay_idle_registry_poison_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"] +mod middle_relay_zero_length_frame_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"] +mod middle_relay_tiny_frame_debt_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"] +mod middle_relay_tiny_frame_debt_concurrency_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] +mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"] +mod middle_relay_atomic_quota_invariant_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index eebc188..5880558 100644 --- 
a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -4,58 +4,58 @@ #![cfg_attr(test, allow(warnings))] #![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] #![cfg_attr( - not(test), - deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::panic, - clippy::todo, - clippy::unimplemented, - clippy::correctness, - clippy::option_if_let_else, - clippy::or_fun_call, - clippy::branches_sharing_code, - clippy::single_option_map, - clippy::useless_let_if_seq, - clippy::redundant_locals, - clippy::cloned_ref_to_slice_refs, - unsafe_code, - clippy::await_holding_lock, - clippy::await_holding_refcell_ref, - clippy::debug_assert_with_mut_call, - clippy::macro_use_imports, - clippy::cast_ptr_alignment, - clippy::cast_lossless, - clippy::ptr_as_ptr, - clippy::large_stack_arrays, - clippy::same_functions_in_if_condition, - trivial_casts, - trivial_numeric_casts, - unused_extern_crates, - unused_import_braces, - rust_2018_idioms - ) + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) )] #![cfg_attr( - not(test), - allow( - clippy::use_self, - clippy::redundant_closure, - clippy::too_many_arguments, - clippy::doc_markdown, - clippy::missing_const_for_fn, - clippy::unnecessary_operation, - clippy::redundant_pub_crate, - clippy::derive_partial_eq_without_eq, - 
clippy::type_complexity, - clippy::new_ret_no_self, - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - clippy::significant_drop_tightening, - clippy::significant_drop_in_scrutinee, - clippy::float_cmp, - clippy::nursery - ) + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) )] pub mod adaptive_buffers; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2431ff4..6000e18 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,13 +52,12 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; -use crate::stats::Stats; +use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; -use dashmap::DashMap; use std::io; use std::pin::Pin; +use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; @@ -209,12 +208,10 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, + user_stats: Arc, quota_limit: Option, quota_exceeded: Arc, - quota_read_wake_scheduled: bool, - quota_write_wake_scheduled: bool, - quota_read_retry_active: Arc, - quota_write_retry_active: Arc, + quota_bytes_since_check: u64, epoch: Instant, } @@ -230,30 +227,21 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); + let user_stats = 
stats.get_or_create_user_stats_handle(&user); Self { inner, counters, stats, user, + user_stats, quota_limit, quota_exceeded, - quota_read_wake_scheduled: false, - quota_write_wake_scheduled: false, - quota_read_retry_active: Arc::new(AtomicBool::new(false)), - quota_write_retry_active: Arc::new(AtomicBool::new(false)), + quota_bytes_since_check: 0, epoch, } } } -impl Drop for StatsIo { - fn drop(&mut self) { - self.quota_read_retry_active.store(false, Ordering::Relaxed); - self.quota_write_retry_active - .store(false, Ordering::Relaxed); - } -} - #[derive(Debug)] struct QuotaIoSentinel; @@ -277,84 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool { .is_some() } -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); +const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; +const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; -fn spawn_quota_retry_waker(retry_active: Arc, waker: std::task::Waker) { - tokio::task::spawn(async move { - loop { - if !retry_active.load(Ordering::Relaxed) { - break; - } - tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await; - if !retry_active.load(Ordering::Relaxed) { - break; - } - waker.wake_by_ref(); - } - }); +#[inline] +fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { + remaining_before.saturating_div(2).clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) } -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -#[cfg(test)] 
-fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } +#[inline] +fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool { + remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } impl AsyncRead for StatsIo { @@ -364,80 +290,60 @@ impl AsyncRead for StatsIo { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => { - 
this.quota_read_wake_scheduled = false; - this.quota_read_retry_active.store(false, Ordering::Relaxed); - Some(guard) - } - Err(_) => { - if !this.quota_read_wake_scheduled { - this.quota_read_wake_scheduled = true; - this.quota_read_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_read_retry_active), - cx.waker().clone(), - ); - } - return Poll::Pending; - } + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); } - } else { - None - }; - - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + remaining_before = Some(remaining); } + let before = buf.filled().len(); match Pin::new(&mut this.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { - let mut reached_quota_boundary = false; - if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - let remaining = limit - used; - if (n as u64) > remaining { - // Fail closed: when a single read chunk would cross quota, - // stop relay immediately without accounting beyond the cap. 
- this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - reached_quota_boundary = (n as u64) == remaining; - } + let n_to_charge = n as u64; // C→S: client sent data this.counters .c2s_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_from(&this.user, n as u64); - this.stats.increment_user_msgs_from(&this.user); + this.stats + .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_from_handle(this.user_stats.as_ref()); - if reached_quota_boundary { - this.quota_exceeded.store(true, Ordering::Relaxed); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "C->S"); @@ -456,75 +362,57 @@ impl AsyncWrite for StatsIo { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let quota_lock = this - .quota_limit - .is_some() - .then(|| quota_user_lock(&this.user)); - let _quota_guard = if let Some(lock) = quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => 
{ - this.quota_write_wake_scheduled = false; - this.quota_write_retry_active - .store(false, Ordering::Relaxed); - Some(guard) - } - Err(_) => { - if !this.quota_write_wake_scheduled { - this.quota_write_wake_scheduled = true; - this.quota_write_retry_active.store(true, Ordering::Relaxed); - spawn_quota_retry_waker( - Arc::clone(&this.quota_write_retry_active), - cx.waker().clone(), - ); - } - return Poll::Pending; - } - } - } else { - None - }; - - let write_buf = if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } + remaining_before = Some(remaining); + } - let remaining = (limit - used) as usize; - if buf.len() > remaining { - // Fail closed: do not emit partial S->C payload when remaining - // quota cannot accommodate the pending write request. 
- this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - buf - } else { - buf - }; - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + match Pin::new(&mut this.inner).poll_write(cx, buf) { Poll::Ready(Ok(n)) => { if n > 0 { + let n_to_charge = n as u64; + // S→C: data written to client this.counters .s2c_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_to(&this.user, n as u64); - this.stats.increment_user_msgs_to(&this.user); + this.stats + .add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_to_handle(this.user_stats.as_ref()); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "S->C"); @@ -618,7 +506,7 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); - if wd_quota_exceeded.load(Ordering::Relaxed) { + if wd_quota_exceeded.load(Ordering::Acquire) { 
warn!(user = %wd_user, "User data quota reached, closing relay"); return; } @@ -756,18 +644,10 @@ where } } -#[cfg(test)] -#[path = "tests/relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] mod adversarial_tests; -#[cfg(test)] -#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; - #[cfg(test)] #[path = "tests/relay_quota_boundary_blackhat_tests.rs"] mod relay_quota_boundary_blackhat_tests; @@ -780,14 +660,14 @@ mod relay_quota_model_adversarial_tests; #[path = "tests/relay_quota_overflow_regression_tests.rs"] mod relay_quota_overflow_regression_tests; +#[cfg(test)] +#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"] +mod relay_quota_extended_attack_surface_security_tests; + #[cfg(test)] #[path = "tests/relay_watchdog_delta_security_tests.rs"] mod relay_watchdog_delta_security_tests; #[cfg(test)] -#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] -mod relay_quota_waker_storm_adversarial_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; +#[path = "tests/relay_atomic_quota_invariant_tests.rs"] +mod relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs new file mode 100644 index 0000000..f462ed8 --- /dev/null +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -0,0 +1,467 @@ +use super::*; +use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex}; +use 
tokio::net::TcpListener; + +#[test] +fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 1000; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + let elapsed = start.elapsed(); + + assert!(elapsed >= Duration::from_millis(1000)); + assert!(elapsed < Duration::from_millis(1500)); + }); +} + +#[test] +fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert_eq!(timeout.as_secs(), u64::MAX); +} + +#[test] +fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); +} + +#[test] +fn edge_synthetic_local_addr_boundaries() { + assert_eq!(synthetic_local_addr(0).port(), 0); + assert_eq!(synthetic_local_addr(80).port(), 80); + assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX); +} + +#[test] +fn edge_beobachten_record_handshake_failure_class_stream_error_eof() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof); + let peer_ip: IpAddr = "198.51.100.100".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); 
+ assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn adversarial_tls_handshake_timeout_during_masking_delay() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.server_hello_delay_min_ms = 3000; + cfg.censorship.server_hello_delay_max_ms = 3000; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.1:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); + assert_eq!(stats.get_handshake_timeouts(), 1); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_slowloris_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 200; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.2:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + 
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(b"PROXY TCP4 192.").await.unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[test] +fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() { + let trusted = vec!["192.0.2.0/24".parse().unwrap()]; + let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + assert!(!is_trusted_proxy_source(peer_ip, &trusted)); +} + +#[tokio::test] +async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 500; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.3:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_client_stream_exactly_4_bytes_eof() { + let config = 
Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.4:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.5:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); + client_side.write_all(&vec![0x41; 99]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), 
handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn integration_non_tls_modes_disabled_immediately_masks() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +struct YieldingReader { + data: Vec, + pos: usize, + yields_left: usize, +} + +impl AsyncRead for YieldingReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.yields_left > 0 { + this.yields_left -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if this.pos >= this.data.len() { + return Poll::Ready(Ok(())); + } + buf.put_slice(&this.data[this.pos..this.pos + 1]); + this.pos += 1; + this.yields_left = 2; + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn fuzz_read_with_progress_heavy_yielding() { + let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec(); + let mut reader = YieldingReader { + data: expected_data.clone(), + pos: 0, + yields_left: 2, + }; + + let mut buf = vec![0u8; expected_data.len()]; + let read_bytes = read_with_progress(&mut reader, &mut 
buf).await.unwrap(); + + assert_eq!(read_bytes, expected_data.len()); + assert_eq!(buf, expected_data); +} + +#[test] +fn edge_wrap_tls_application_record_exactly_u16_max() { + let payload = vec![0u8; 65535]; + let wrapped = wrap_tls_application_record(&payload); + assert_eq!(wrapped.len(), 65540); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); +} + +#[test] +fn fuzz_wrap_tls_application_record_lengths() { + let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072]; + for len in lengths { + let payload = vec![0u8; len]; + let wrapped = wrap_tls_application_record(&payload); + let expected_chunks = len.div_ceil(65535).max(1); + assert_eq!(wrapped.len(), len + 5 * expected_chunks); + } +} + +#[tokio::test] +async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { + let user = "stress-same-ip-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 5); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 10).await; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000); + + let mut tasks = tokio::task::JoinSet::new(); + let mut reservations = Vec::new(); + + for _ in 0..10 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await + }); + } + + let mut successes = 0; + let mut failures = 0; + + while let Some(res) = tasks.join_next().await { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, 5); + assert_eq!(failures, 5); + assert_eq!(stats.get_user_curr_connects(user), 5); + 
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for reservation in reservations { + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs new file mode 100644 index 0000000..e57f817 --- /dev/null +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -0,0 +1,222 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, duplex}; + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[test] +fn invariant_wrap_tls_application_record_exact_multiples() { + let chunk_size = u16::MAX as usize; + let payload = vec![0xAA; chunk_size * 2]; + + let wrapped = wrap_tls_application_record(&payload); + + assert_eq!(wrapped.len(), 2 * (5 + chunk_size)); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); + + let second_header_idx = 5 + chunk_size; + assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION); + assert_eq!( + &wrapped[second_header_idx + 3..second_header_idx + 5], + &65535u16.to_be_bytes() + ); +} + +#[tokio::test] +async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.20:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + 
stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16; + let mut header = vec![0x16, 0x03, 0x01]; + header.extend_from_slice(&claimed_len.to_be_bytes()); + + client_side.write_all(&header).await.unwrap(); + client_side + .write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn invariant_acquire_reservation_ip_limit_rollback() { + let user = "rollback-test-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer_a = "198.51.100.21:55000".parse().unwrap(); + let _res_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_a, + ip_tracker.clone(), + ) + .await + .unwrap(); + + assert_eq!(stats.get_user_curr_connects(user), 1); + + let peer_b = "203.0.113.22:55000".parse().unwrap(); + let res_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_b, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + res_b, + Err(ProxyError::ConnectionLimitExceeded { .. 
}) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); +} + +#[tokio::test] +async fn invariant_quota_exact_boundary_inclusive() { + let user = "quota-strict-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1000); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.23:55000".parse().unwrap(); + + preload_user_quota(stats.as_ref(), user, 999); + let res1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(res1.is_ok()); + res1.unwrap().release().await; + + preload_user_quota(stats.as_ref(), user, 1); + let res2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.25:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = 
tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_err()); + assert_eq!(stats.get_connects_bad(), 0); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn invariant_route_mode_snapshot_picks_up_latest_mode() { + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Direct + )); + + route_runtime.set_mode(RelayRouteMode::Middle); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Middle + )); +} diff --git a/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs new file mode 100644 index 0000000..d7ac4ef --- /dev/null +++ b/src/proxy/tests/client_masking_fragmented_classifier_security_tests.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +#[tokio::test] +async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = 
true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(b"CONNE").await.unwrap(); + client_side + .write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"), + "mask backend must receive the full fragmented CONNECT probe" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.251-1")); +} diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs new file mode 100644 index 0000000..3036f95 --- /dev/null +++ 
b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + 
Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + let first = split_at.min(preface.len()); + client_side.write_all(&preface[..first]).await.unwrap(); + if first < preface.len() { + sleep(Duration::from_millis(delay_ms)).await; + client_side.write_all(&preface[first..]).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert!( + forwarded.starts_with(&preface), + "mask backend must receive an intact HTTP/2 preface prefix" + ); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains(&format!("{}-1", peer.ip()))); +} + +#[tokio::test] +async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { + let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)]; + + for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { + let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} + +#[tokio::test] +async fn http2_preface_splitpoint_light_fuzz_classifies_http() { + for split_at in 2usize..=12 { + let delay_ms = if split_at % 3 == 0 { 7 } else { 1 }; + let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at) + .parse() + .unwrap(); + run_http2_fragment_case(split_at, delay_ms, peer).await; + } +} diff --git a/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs new file mode 100644 index 0000000..e64dc03 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs @@ -0,0 +1,150 @@ +use super::*; +use 
crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, sleep}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_pipeline_prefetch_case( + prefetch_timeout_ms: u64, + delayed_tail_ms: u64, + peer: SocketAddr, +) -> (Vec, String) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + got + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); 
+ + client_side.write_all(b"C").await.unwrap(); + sleep(Duration::from_millis(delayed_tail_ms)).await; + + client_side + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n") + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + (forwarded, snapshot) +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() { + let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must still receive full payload bytes in-order" + ); + assert!( + snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"), + "unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}" + ); +} + +#[tokio::test] +async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() { + let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap(); + let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await; + + assert!( + forwarded.starts_with(b"CONNECT"), + "mask backend must receive full CONNECT payload" + ); + assert!( + snapshot.contains("[HTTP]"), + "20ms budget should recover delayed fragmented prefix and classify as HTTP" + ); +} + +#[tokio::test] +async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() { + let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap(); + let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap(); + let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap(); + + let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await; + let (_, snap20) = 
run_pipeline_prefetch_case(20, 35, peer20).await; + let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await; + + assert!( + snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"), + "unexpected 5ms snapshot: {snap5}" + ); + assert!( + snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"), + "unexpected 20ms snapshot: {snap20}" + ); + assert!(snap50.contains("[HTTP]")); +} diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs new file mode 100644 index 0000000..64e7a85 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -0,0 +1,88 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep}; + +#[test] +fn prefetch_timeout_budget_reads_from_config() { + let mut cfg = ProxyConfig::default(); + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(5), + "default prefetch timeout budget must remain 5ms" + ); + + cfg.censorship.mask_classifier_prefetch_timeout_ms = 20; + assert_eq!( + mask_classifier_prefetch_timeout(&cfg), + Duration::from_millis(20), + "runtime prefetch timeout budget must follow configured value" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(20), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + 
initial_data.starts_with(b"CONNECT"), + "20ms configured prefetch budget should recover 15ms delayed CONNECT tail" + ); +} + +#[tokio::test] +async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(15)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(5), + ) + .await; + + writer_task + .await + .expect("writer task must not panic in runtime timeout test"); + + assert!( + !initial_data.starts_with(b"CONNECT"), + "5ms configured prefetch budget should miss 15ms delayed CONNECT tail" + ); +} diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs new file mode 100644 index 0000000..b49db3c --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -0,0 +1,264 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + 
cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + 
record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[test] +fn empty_initial_data_prefetch_gate_is_fail_closed() { + assert!( + !should_prefetch_mask_classifier_window(&[]), + "empty initial_data must not trigger classifier prefetch" + ); +} + +#[tokio::test] +async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() { + let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec(); + let (mut reader, mut writer) = duplex(1024); + + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.is_empty(), + "empty initial_data must remain empty after prefetch stage" + ); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!( + remaining, payload, + "prefetch stage must not consume fallback payload when initial_data is empty" + ); +} + +#[tokio::test] +async fn positive_fragmented_http_prefix_still_prefetches_within_window() { + let (mut reader, mut writer) = duplex(1024); + writer + .write_all(b"NECT example.org:443 HTTP/1.1\r\n") + .await + .unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = b"CON".to_vec(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + + assert!( + initial_data.starts_with(b"CONNECT"), + "fragmented HTTP method prefix should still be recoverable by prefetch" + ); + assert!( + initial_data.len() <= 16, + "prefetch window must remain bounded" + ); +} + +#[tokio::test] +async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() { + 
let mut seed = 0xD15C_A11E_2026_0322u64; + + for _ in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = ((seed & 0x3f) as usize).saturating_add(1); + let mut payload = vec![0u8; len]; + for (idx, byte) in payload.iter_mut().enumerate() { + *byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17); + } + + let (mut reader, mut writer) = duplex(1024); + writer.write_all(&payload).await.unwrap(); + writer.shutdown().await.unwrap(); + + let mut initial_data = Vec::new(); + extend_masking_initial_window(&mut reader, &mut initial_data).await; + assert!(initial_data.is_empty()); + + let mut remaining = Vec::new(); + reader.read_to_end(&mut remaining).await.unwrap(); + assert_eq!(remaining, payload); + } +} + +#[tokio::test] +async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xD3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B); + let mut invalid_payload = vec![0u8; HANDSHAKE_LEN]; + invalid_payload[0] = 0xFF; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload); + let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant"); + let expected = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!( + n, 0, + "fallback stream must not append synthetic bytes on empty initial_data path" + ); + }); + + let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + 
"198.51.100.245:56145".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs new file mode 100644 index 0000000..cbb6603 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -0,0 +1,72 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, advance, sleep}; + +async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(tail_delay_ms)).await; + let _ = writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await; + let _ = writer.shutdown().await; + }); + + let mut initial_data = b"C".to_vec(); + let mut prefetch_task = tokio::spawn(async move { + extend_masking_initial_window_with_timeout( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_ms), + ) + .await; + initial_data + }); + + tokio::task::yield_now().await; + + if tail_delay_ms > 0 { + 
advance(Duration::from_millis(tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + if prefetch_ms > tail_delay_ms { + advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await; + tokio::task::yield_now().await; + } + + let result = prefetch_task.await.expect("prefetch task must not panic"); + writer_task.await.expect("writer task must not panic"); + result +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_5ms_misses_15ms_tail() { + let got = run_strict_prefetch_case(5, 15).await; + assert_eq!(got, b"C".to_vec()); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_20ms_recovers_15ms_tail() { + let got = run_strict_prefetch_case(20, 15).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_50ms_recovers_35ms_tail() { + let got = run_strict_prefetch_case(50, 35).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_equal_budget_and_delay_recovers_tail() { + let got = run_strict_prefetch_case(20, 20).await; + assert!(got.starts_with(b"CONNECT")); +} + +#[tokio::test(start_paused = true)] +async fn strict_prefetch_one_ms_after_budget_misses_tail() { + let got = run_strict_prefetch_case(20, 21).await; + assert_eq!(got, b"C".to_vec()); +} diff --git a/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs new file mode 100644 index 0000000..bee1eb3 --- /dev/null +++ b/src/proxy/tests/client_masking_prefetch_timing_matrix_security_tests.rs @@ -0,0 +1,98 @@ +use super::*; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, sleep, timeout}; + +async fn extend_masking_initial_window_with_budget( + reader: &mut R, + initial_data: &mut Vec, + prefetch_timeout: Duration, +) where + R: AsyncRead + Unpin, +{ + if !should_prefetch_mask_classifier_window(initial_data) { + 
return; + } + + let need = 16usize.saturating_sub(initial_data.len()); + if need == 0 { + return; + } + + let mut extra = [0u8; 16]; + if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await + && n > 0 + { + initial_data.extend_from_slice(&extra[..n]); + } +} + +async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool { + let (mut reader, mut writer) = duplex(1024); + + let writer_task = tokio::spawn(async move { + sleep(Duration::from_millis(delayed_tail_ms)).await; + writer + .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") + .await + .expect("tail bytes must be writable"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); + }); + + let mut initial_data = b"C".to_vec(); + extend_masking_initial_window_with_budget( + &mut reader, + &mut initial_data, + Duration::from_millis(prefetch_budget_ms), + ) + .await; + + writer_task + .await + .expect("writer task must not panic during matrix case"); + + initial_data.starts_with(b"CONNECT") +} + +#[tokio::test] +async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() { + let cases = [ + // (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50]) + (2u64, [true, true, true]), + (15u64, [false, true, true]), + (35u64, [false, false, true]), + ]; + + for (tail_delay_ms, expected) in cases { + let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await; + let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await; + let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await; + + assert_eq!( + got_5, expected[0], + "5ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_20, expected[1], + "20ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + assert_eq!( + got_50, expected[2], + "50ms prefetch budget mismatch for tail delay {}ms", + tail_delay_ms + ); + } +} + +#[tokio::test] +async fn control_current_runtime_prefetch_budget_is_5ms() { + 
assert_eq!( + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + Duration::from_millis(5), + "matrix assumptions require current runtime prefetch budget to stay at 5ms" + ); +} diff --git a/src/proxy/tests/client_masking_replay_timing_security_tests.rs b/src/proxy/tests/client_masking_replay_timing_security_tests.rs new file mode 100644 index 0000000..c3339e8 --- /dev/null +++ b/src/proxy/tests/client_masking_replay_timing_security_tests.rs @@ -0,0 +1,167 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +async fn run_replay_candidate_session( + replay_checker: Arc, + hello: &[u8], + peer: SocketAddr, + 
drive_mtproto_fail: bool, +) -> Duration { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.mask_timing_normalization_enabled = false; + cfg.access.ignore_time_skew = true; + cfg.access.users.insert( + "user".to_string(), + "abababababababababababababababab".to_string(), + ); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let task = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + new_upstream_manager(stats), + replay_checker, + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten, + false, + )); + + client_side.write_all(hello).await.unwrap(); + + if drive_mtproto_fail { + let mut server_hello_head = [0u8; 5]; + client_side + .read_exact(&mut server_hello_head) + .await + .unwrap(); + assert_eq!(server_hello_head[0], 0x16); + let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN); + invalid_mtproto_record.push(0x17); + invalid_mtproto_record.extend_from_slice(&TLS_VERSION); + invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes()); + invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); + client_side + .write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n") + .await + .unwrap(); + } + + 
client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + started.elapsed() +} + +#[tokio::test] +async fn replay_reject_still_honors_masking_timing_budget() { + let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60))); + let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51); + + let seed_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.201:58001".parse().unwrap(), + true, + ) + .await; + + assert!( + seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250), + "seed replay-candidate run must honor masking timing budget without unbounded delay" + ); + + let replay_elapsed = run_replay_candidate_session( + Arc::clone(&replay_checker), + &hello, + "198.51.100.202:58002".parse().unwrap(), + false, + ) + .await; + + assert!( + replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250), + "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" + ); +} diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs new file mode 100644 index 0000000..8f9d832 --- /dev/null +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -0,0 +1,288 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[tokio::test] +async fn edge_mask_delay_bypassed_if_max_is_zero() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 10_000; + config.censorship.server_hello_delay_max_ms = 0; + + let 
start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + assert!(start.elapsed() < Duration::from_millis(50)); +} + +#[test] +fn edge_beobachten_ttl_clamps_exactly_to_24_hours() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 100_000; + + let ttl = beobachten_ttl(&config); + assert_eq!(ttl.as_secs(), 24 * 60 * 60); +} + +#[test] +fn edge_wrap_tls_application_record_empty_payload() { + let wrapped = wrap_tls_application_record(&[]); + assert_eq!(wrapped.len(), 5); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &[0, 0]); +} + +#[tokio::test] +async fn boundary_user_data_quota_exact_match_rejects() { + let user = "quota-boundary-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Arc::new(Stats::new()); + preload_user_quota(stats.as_ref(), user, 1024); + + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.10:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn boundary_user_expiration_in_past_rejects() { + let user = "expired-boundary-user"; + let mut config = ProxyConfig::default(); + let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1); + config + .access + .user_expirations + .insert(user.to_string(), expired_time); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.11:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, &config, stats, peer, ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::UserExpired { .. 
}))); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 300; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.12:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&vec![b'A'; 2000]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, 
+ )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); + + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn security_classic_mode_disabled_masks_valid_length_payload() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.15:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&vec![0xEF; 64]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { + let user = "rapid-churn-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer = "198.51.100.16:55000".parse().unwrap(); + + for _ in 0..500 { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + 
.unwrap(); + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() { + let (mut server_side, _client_side) = duplex(4096); + let mut empty_buf = &mut [][..]; + + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut empty_buf), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().unwrap(), 0); +} + +#[tokio::test] +async fn stress_read_with_progress_cancellation_safety() { + let (mut server_side, mut client_side) = duplex(4096); + + client_side.write_all(b"12345").await.unwrap(); + + let mut buf = [0u8; 10]; + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut buf), + ) + .await; + + assert!(result.is_err()); + + client_side.write_all(b"67890").await.unwrap(); + let mut buf2 = [0u8; 5]; + server_side.read_exact(&mut buf2).await.unwrap(); + assert_eq!(&buf2, b"67890"); +} diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 6338e23..1b46c6d 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,6 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -25,6 +28,220 @@ fn synthetic_local_addr_uses_configured_port_for_max() { assert_eq!(addr.port(), u16::MAX); } +#[test] +fn handshake_timeout_with_mask_grace_includes_mask_margin() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 2; + + 
config.censorship.mask = false; + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(2) + ); + + config.censorship.mask = true; + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(2750), + "mask mode extends handshake timeout by 750 ms" + ); +} + +#[tokio::test] +async fn read_with_progress_reads_partial_buffers_before_eof() { + let data = vec![0xAA, 0xBB, 0xCC]; + let mut reader = std::io::Cursor::new(data); + let mut buf = [0u8; 5]; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 3); + assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]); +} + +#[test] +fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() { + let peer: IpAddr = "10.10.10.10".parse().unwrap(); + assert!(!is_trusted_proxy_source(peer, &[])); + + let trusted = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trusted)); + + let not_trusted = vec!["192.0.2.0/24".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, ¬_trusted)); +} + +#[test] +fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() { + let peer: IpAddr = "203.0.113.42".parse().unwrap(); + let trust_all = vec!["0.0.0.0/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trust_all)); + + let peer_v6: IpAddr = "2001:db8::1".parse().unwrap(); + let trust_all_v6 = vec!["::/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6)); +} + +struct ErrorReader; + +impl tokio::io::AsyncRead for ErrorReader { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "fake error", + ))) + } +} + +#[tokio::test] +async fn read_with_progress_returns_error_from_failed_reader() { + let mut reader = ErrorReader; + let mut buf = [0u8; 8]; + let err = read_with_progress(&mut reader, &mut 
buf).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); +} + +#[test] +fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert!(timeout >= Duration::from_secs(u64::MAX)); +} + +#[tokio::test] +async fn read_with_progress_zero_length_buffer_returns_zero() { + let data = vec![1, 2, 3]; + let mut reader = std::io::Cursor::new(data); + let mut buf = []; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 0); +} + +#[test] +fn handshake_timeout_without_mask_is_exact_base() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 7; + config.censorship.mask = false; + + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(7) + ); +} + +#[test] +fn handshake_timeout_mask_enabled_adds_750ms() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 3; + config.censorship.mask = true; + + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(3750) + ); +} + +#[tokio::test] +async fn read_with_progress_full_then_empty_transition() { + let data = vec![0x10, 0x20]; + let mut cursor = std::io::Cursor::new(data); + let mut buf = [0u8; 2]; + + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2); + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn read_with_progress_fragmented_io_works_over_multiple_calls() { + let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]); + let mut result = Vec::new(); + + for chunk_size in 1..=5 { + let mut b = vec![0u8; chunk_size]; + let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); + result.extend_from_slice(&b[..n]); + if n == 0 { + break; + } + } + + assert_eq!(result, 
vec![1, 2, 3, 4, 5]); +} + +#[tokio::test] +async fn read_with_progress_stress_randomized_chunk_sizes() { + for i in 0..128 { + let mut rng = StdRng::seed_from_u64(i as u64 + 1); + let mut input: Vec = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect(); + let mut cursor = std::io::Cursor::new(input.clone()); + let mut collected = Vec::new(); + + while cursor.position() < cursor.get_ref().len() as u64 { + let chunk = 1 + (rng.next_u32() as usize % 8); + let mut b = vec![0u8; chunk]; + let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); + collected.extend_from_slice(&b[..read]); + if read == 0 { + break; + } + } + + assert_eq!(collected, input); + } +} + +#[test] +fn is_trusted_proxy_source_boundary_narrow_ipv4() { + let matching = "172.16.0.1".parse().unwrap(); + let not_matching = "172.15.255.255".parse().unwrap(); + let cidr = vec!["172.16.0.0/12".parse().unwrap()]; + assert!(is_trusted_proxy_source(matching, &cidr)); + assert!(!is_trusted_proxy_source(not_matching, &cidr)); +} + +#[test] +fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() { + let peer = "2001:db8::1".parse().unwrap(); + let cidr = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, &cidr)); +} + +#[test] +fn wrap_tls_application_record_reserved_chunks_look_reasonable() { + let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2]; + let wrapped = wrap_tls_application_record(&payload); + assert!(wrapped.len() > payload.len()); + assert!(wrapped.contains(&0x17)); +} + +#[test] +fn wrap_tls_application_record_roundtrip_size_check() { + let payload_len = 3000; + let payload = vec![0x55; payload_len]; + let wrapped = wrap_tls_application_record(&payload); + + let mut idx = 0; + let mut consumed = 0; + while idx + 5 <= wrapped.len() { + assert_eq!(wrapped[idx], 0x17); + let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize; + consumed += len; + idx += 5 + len; + if idx >= wrapped.len() { + break; + } + } + + assert_eq!(consumed, 
payload_len); +} + fn make_crypto_reader(reader: R) -> CryptoReader where R: tokio::io::AsyncRead + Unpin, @@ -43,6 +260,11 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); @@ -2841,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { .insert("user".to_string(), 1024); let stats = Stats::new(); - stats.add_user_octets_from("user", 1024); + preload_user_quota(&stats, "user", 1024); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); diff --git a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs index 08f52d1..7964cdd 100644 --- a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs +++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs @@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize; let body_start = offset + 5; let body_end = body_start + len; - assert!(body_end <= record.len(), "declared TLS record length must be in-bounds"); + assert!( + body_end <= record.len(), + "declared TLS record length must be in-bounds" + ); recovered.extend_from_slice(&record[body_start..body_end]); offset = body_end; frames += 1; } - assert_eq!(offset, record.len(), "record parser must consume exact output size"); - assert_eq!(frames, 2, "oversized payload should split into exactly two records"); - assert_eq!(recovered, payload, "chunked records must preserve full payload"); + assert_eq!( + 
offset, + record.len(), + "record parser must consume exact output size" + ); + assert_eq!( + frames, 2, + "oversized payload should split into exactly two records" + ); + assert_eq!( + recovered, payload, + "chunked records must preserve full payload" + ); } diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 16fe8da..a731830 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() { "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut first = open_unknown_dc_log_append_anchored(&sanitized) @@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() { let content = fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!(lines.len(), 2, "expected one line per anchored append call"); assert!( lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), @@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut workers = Vec::new(); @@ -831,8 +832,15 @@ fn 
anchored_open_parallel_appends_preserve_line_integrity() { let content = fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); - assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); + assert_eq!( + lines.len(), + 64, + "expected one complete line per worker append" + ); for line in lines { assert!( line.starts_with("dc_idx="), @@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() { "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut file = open_unknown_dc_log_append_anchored(&sanitized) @@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() { "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let outside = std::env::temp_dir().join(format!( "telemt-unknown-dc-anchored-symlink-outside-{}.log", @@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let workers = 24usize; @@ -970,7 +975,10 @@ fn 
anchored_open_high_contention_multi_write_preserves_complete_lines() { let content = fs::read_to_string(&sanitized.resolved_path) .expect("contention output file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!( lines.len(), workers * rounds, @@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); let mut readonly = std::fs::OpenOptions::new() diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs new file mode 100644 index 0000000..76347c4 --- /dev/null +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -0,0 +1,719 @@ +use super::*; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn 
make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} 
+ +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Edge Cases & Protocol Boundaries --- + +#[tokio::test] +async fn tls_minimum_viable_length_boundary() { + let _guard = auth_probe_test_guard(); + 
clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut exact_min_handshake = vec![0x42u8; min_len]; + exact_min_handshake[min_len - 1] = 0; + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let digest = sha256_hmac(&secret, &exact_min_handshake); + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + let res = handle_tls_handshake( + &exact_min_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::Success(_)), + "Exact minimum length TLS handshake must succeed" + ); + + let short_handshake = vec![0x42u8; min_len - 1]; + let res_short = handle_tls_handshake( + &short_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res_short, HandshakeResult::BadClient { .. }), + "Handshake 1 byte shorter than minimum must fail closed" + ); +} + +#[tokio::test] +async fn mtproto_extreme_dc_index_serialization() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "22222222222222222222222222222222"; + let config = test_config_with_secret_hex(secret_hex); + for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() { + // Keep replay state independent per case so we validate dc_idx encoding, + // not duplicate-handshake rejection behavior. 
+ let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16); + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc); + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!( + success.dc_idx, extreme_dc, + "Extreme DC index {} must serialize/deserialize perfectly", + extreme_dc + ); + } + _ => panic!( + "MTProto handshake with extreme DC index {} failed", + extreme_dc + ), + } + } +} + +#[tokio::test] +async fn alpn_strict_case_and_padding_rejection() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"]; + + for bad_alpn in bad_alpns { + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]); + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. 
}), + "ALPN strict enforcement must reject {:?}", + bad_alpn + ); + } +} + +#[test] +fn ipv4_mapped_ipv6_bucketing_anomaly() { + let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402)); + + let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); + let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); + + assert_eq!( + norm_1, norm_2, + "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)" + ); + assert_eq!( + norm_1, + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), + "The bucket must be exactly ::0" + ); +} + +// --- Category 2: Adversarial & Black Hat --- + +#[tokio::test] +async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "55555555555555555555555555555555"; + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut invalid_handshake = valid_handshake; + invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF; + + let res_invalid = handle_mtproto_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. 
})); + + let res_valid = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid MTProto ciphertext must not poison the replay cache" + ); +} + +#[tokio::test] +async fn tls_invalid_session_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x66u8; 16]; + let config = test_config_with_secret_hex("66666666666666666666666666666666"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + let mut invalid_handshake = valid_handshake.clone(); + let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + invalid_handshake[session_idx] ^= 0xFF; + + let res_invalid = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. 
})); + + let res_valid = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid TLS payload must not poison the replay cache" + ); +} + +#[tokio::test] +async fn server_hello_delay_timing_neutrality_on_hmac_failure() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.server_hello_delay_min_ms = 50; + config.censorship.server_hello_delay_max_ms = 50; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let start = Instant::now(); + let res = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); + assert!( + elapsed >= Duration::from_millis(45), + "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels" + ); +} + +#[tokio::test] +async fn server_hello_delay_inversion_resilience() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x88u8; 16]; + let mut config = test_config_with_secret_hex("88888888888888888888888888888888"); + config.censorship.server_hello_delay_min_ms = 100; + config.censorship.server_hello_delay_max_ms = 10; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let start = Instant::now(); + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::Success(_))); + assert!( + elapsed >= Duration::from_millis(90), + "Delay logic must gracefully handle min > max inversions via max.max(min)" + ); +} + +#[tokio::test] +async fn mixed_valid_and_invalid_user_secrets_configuration() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + let _warn_guard = warned_secrets_test_lock().lock().unwrap(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.ignore_time_skew = true; + + for i in 0..9 { + let bad_secret = if i % 2 == 0 { "badhex!" 
} else { "1122" }; + config + .access + .users + .insert(format!("bad_user_{}", i), bad_secret.to_string()); + } + let valid_secret_hex = "99999999999999999999999999999999"; + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + config.general.modes.secure = true; + config.general.modes.classic = true; + config.general.modes.tls = true; + + let secret = [0x99u8; 16]; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "Proxy must gracefully skip invalid secrets and authenticate the valid one" + ); +} + +#[tokio::test] +async fn tls_emulation_fallback_when_cache_missing() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0xAAu8; 16]; + let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + config.censorship.tls_emulation = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "TLS emulation must gracefully fall back to standard ServerHello if cache is missing" + ); +} + +#[tokio::test] +async fn classic_mode_over_tls_transport_protocol_confusion() { + let _guard = auth_probe_test_guard(); + 
clear_auth_probe_state_for_testing(); + + let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.classic = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap(); + + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + true, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::Success(_)), + "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior" + ); +} + +#[test] +fn generate_tg_nonce_never_emits_reserved_bytes() { + let client_enc_key = [0xCCu8; 32]; + let client_enc_iv = 123456789u128; + let rng = SecureRandom::new(); + + for _ in 0..10_000 { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 1, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert!( + !RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), + "Nonce must never start with reserved bytes" + ); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + assert!( + !RESERVED_NONCE_BEGINNINGS.contains(&first_four), + "Nonce must never match reserved 4-byte beginnings" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn dashmap_concurrent_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip_a: IpAddr = "192.0.2.13".parse().unwrap(); + let ip_b: IpAddr = "198.51.100.13".parse().unwrap(); + let mut tasks = Vec::new(); + + for i in 0..100 { + let target_ip = if i % 2 == 0 { ip_a } else { ip_b }; + tasks.push(tokio::spawn(async move { + for _ in 0..50 { + auth_probe_record_failure(target_ip, Instant::now()); + } + })); + } + + for task in tasks { 
+ task.await + .expect("Task panicked during concurrent DashMap stress"); + } + + assert!( + auth_probe_is_throttled_for_testing(ip_a), + "IP A must be throttled after concurrent stress" + ); + assert!( + auth_probe_is_throttled_for_testing(ip_b), + "IP B must be throttled after concurrent stress" + ); +} + +#[test] +fn prototag_invalid_bytes_fail_closed() { + let invalid_tags: [[u8; 4]; 5] = [ + [0, 0, 0, 0], + [0xFF, 0xFF, 0xFF, 0xFF], + [0xDE, 0xAD, 0xBE, 0xEF], + [0xDD, 0xDD, 0xDD, 0xDE], + [0x11, 0x22, 0x33, 0x44], + ]; + + for tag in invalid_tags { + assert_eq!( + ProtoTag::from_bytes(tag), + None, + "Invalid ProtoTag bytes {:?} must fail closed", + tag + ); + } +} + +#[test] +fn auth_probe_eviction_hash_collision_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let now = Instant::now(); + + for i in 0..10_000u32 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8)); + auth_probe_record_failure_with_state(state, ip, now); + } + + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "Eviction logic must successfully bound the map size under heavy insertion stress" + ); +} + +#[test] +fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() { + let client_enc_key = [0xDDu8; 32]; + let client_enc_iv = 987654321u128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce); + let zeros = [0u8; 64]; + let returned_keystream = returned_encryptor.encrypt(&zeros); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = 
u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + + let mut manual_input = Vec::new(); + manual_input.extend_from_slice(&nonce); + manual_input.extend_from_slice(&zeros); + let manual_output = manual_encryptor.encrypt(&manual_input); + + assert_eq!( + returned_keystream, + &manual_output[64..128], + "encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length" + ); +} diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs new file mode 100644 index 0000000..77cea19 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -0,0 +1,96 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn adversarial_large_state_offsets_escape_first_scan_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut saw_offset_outside_first_window = false; + for i in 0..8_192u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(131)) & 0xff) as u8, + )); + let now = base + Duration::from_nanos(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + if start >= scan_limit { + saw_offset_outside_first_window = true; + break; + } + } + + assert!( + saw_offset_outside_first_window, + "scan start offset must cover the full auth-probe state, not only the first scan window" + ); +} + +#[test] +fn stress_large_state_offsets_cover_many_scan_windows() { + let _guard = auth_probe_test_guard(); + let base = 
Instant::now(); + let state_len = 65_536usize; + let scan_limit = 1_024usize; + + let mut covered_windows = HashSet::new(); + for i in 0..16_384u64 { + let ip = IpAddr::V4(Ipv4Addr::new( + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + ((i.wrapping_mul(17)) & 0xff) as u8, + )); + let now = base + Duration::from_micros(i); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + covered_windows.insert(start / scan_limit); + } + + assert!( + covered_windows.len() >= 16, + "eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}", + covered_windows.len(), + state_len / scan_limit + ); +} + +#[test] +fn light_fuzz_offset_always_stays_inside_state_len() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xC0FF_EE12_3456_789Au64; + let base = Instant::now(); + + for _ in 0..8_192usize { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x0fff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!( + start < state_len, + "scan offset must stay inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs new file mode 100644 index 0000000..c91a215 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -0,0 +1,99 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn 
edge_zero_state_len_yields_zero_start_offset() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44)); + let now = Instant::now(); + + assert_eq!( + auth_probe_scan_start_offset(ip, now, 0, 16), + 0, + "empty map must not produce non-zero scan offset" + ); +} + +#[test] +fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() { + let _guard = auth_probe_test_guard(); + let base = Instant::now(); + let scan_limit = 16usize; + let state_len = 65_536usize; + + let mut saw_offset_outside_window = false; + for i in 0..2048u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 203, + ((i >> 16) & 0xff) as u8, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + let now = base + Duration::from_micros(i as u64); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + start < state_len, + "start offset must stay within state length; start={start}, len={state_len}" + ); + if start >= scan_limit { + saw_offset_outside_window = true; + break; + } + } + + assert!( + saw_offset_outside_window, + "large-state eviction must sample beyond the first scan window" + ); +} + +#[test] +fn positive_state_smaller_than_scan_limit_caps_to_state_len() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17)); + let now = Instant::now(); + + for state_len in 1..32usize { + let start = auth_probe_scan_start_offset(ip, now, state_len, 64); + assert!( + start < state_len, + "start offset must never exceed state length when scan limit is larger" + ); + } +} + +#[test] +fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { + let _guard = auth_probe_test_guard(); + let mut seed = 0x5A41_5356_4C32_3236u64; + let base = Instant::now(); + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = 
((seed >> 8) as usize % 131_072).saturating_add(1); + let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0xffff); + let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + + assert!( + start < state_len, + "scan offset must stay inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs new file mode 100644 index 0000000..bf97990 --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -0,0 +1,116 @@ +use super::*; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_same_ip_moving_time_yields_diverse_scan_offsets() { + let _guard = auth_probe_test_guard(); + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)); + let base = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..512u64 { + let now = base + Duration::from_nanos(i); + let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16); + uniq.insert(offset); + } + + assert!( + uniq.len() >= 256, + "offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})", + uniq.len() + ); +} + +#[test] +fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() { + let _guard = auth_probe_test_guard(); + let now = Instant::now(); + let mut uniq = HashSet::new(); + + for i in 0..1024u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + (i >> 16) as u8, + (i >> 8) as u8, + i as u8, + (255 - (i as u8)), + )); + uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16)); + } + + assert!( + uniq.len() >= 512, + "scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})", + uniq.len() + ); +} 
+ +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut workers = Vec::new(); + for worker in 0..8u8 { + workers.push(tokio::spawn(async move { + for i in 0..8192u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64)); + } + })); + } + + for worker in workers { + worker.await.expect("saturation worker must not panic"); + } + + assert!( + auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "state must remain hard-capped under parallel saturation churn" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1)); + let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1)); +} + +#[test] +fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { + let _guard = auth_probe_test_guard(); + let mut seed = 0xA55A_1357_2468_9BDFu64; + let base = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1); + let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1); + let now = base + Duration::from_nanos(seed & 0x1fff); + + let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); + assert!( + offset < state_len, + "scan offset must always remain inside state length" + ); + } +} diff --git a/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs new file mode 100644 index 0000000..7176b1c --- /dev/null +++ 
b/src/proxy/tests/handshake_key_material_zeroization_security_tests.rs @@ -0,0 +1,42 @@ +use super::*; + +fn handshake_source() -> &'static str { + include_str!("../handshake.rs") +} + +#[test] +fn security_dec_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"), + "candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_enc_key_derivation_is_zeroized_in_candidate_loop() { + let src = handshake_source(); + assert!( + src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"), + "candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths" + ); +} + +#[test] +fn security_aes_ctr_initialization_uses_zeroizing_references() { + let src = handshake_source(); + assert!( + src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);") + && src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"), + "AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables" + ); +} + +#[test] +fn security_success_struct_copies_out_of_zeroizing_wrappers() { + let src = handshake_source(); + assert!( + src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"), + "HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized" + ); +} diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs new file mode 100644 index 0000000..9782469 --- /dev/null +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -0,0 +1,686 @@ +use super::*; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::collections::HashSet; +use 
std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + 
dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + 
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Timing & Delay Invariants --- + +#[tokio::test] +async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x1Au8; 16]; + let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a"); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 0; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let fut = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ); + + // Deterministic assertion: with max_ms == 0 
there must be no sleep path, + // so the handshake should complete promptly under a generous timeout budget. + let res = tokio::time::timeout(Duration::from_millis(250), fut) + .await + .expect("max_ms=0 should bypass artificial delay and complete quickly"); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99)); + let now = Instant::now(); + + state.insert( + peer_ip, + AuthProbeState { + fail_streak: u32::MAX - 1, + blocked_until: now, + last_seen: now, + }, + ); + + auth_probe_record_failure_with_state(&state, peer_ip, now); + + let updated = state.get(&peer_ip).unwrap(); + assert_eq!(updated.fail_streak, u32::MAX); + + let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); + assert_eq!( + updated.blocked_until, expected_blocked_until, + "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS" + ); +} + +#[test] +fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { + let client_enc_key = [0x2Bu8; 32]; + let client_enc_iv = 1337u128; + let rng = SecureRandom::new(); + + let mut nonces = HashSet::new(); + let mut total_set_bits = 0usize; + let iterations = 5_000; + + for _ in 0..iterations { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + for byte in nonce.iter() { + total_set_bits += byte.count_ones() as usize; + } + + assert!( + nonces.insert(nonce), + "generate_tg_nonce emitted a duplicate nonce! RNG is stuck." + ); + } + + let total_bits = iterations * HANDSHAKE_LEN * 8; + let ratio = (total_set_bits as f64) / (total_bits as f64); + assert!( + ratio > 0.48 && ratio < 0.52, + "Nonce entropy is degraded. 
Set bit ratio: {}", + ratio + ); +} + +#[tokio::test] +async fn mtproto_multi_user_decryption_isolation() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "user_a".to_string(), + "11111111111111111111111111111111".to_string(), + ); + config.access.users.insert( + "user_b".to_string(), + "22222222222222222222222222222222".to_string(), + ); + let good_secret_hex = "33333333333333333333333333333333"; + config + .access + .users + .insert("user_c".to_string(), good_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!( + success.user, "user_c", + "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user" + ); + } + _ => panic!( + "Multi-user MTProto handshake failed. Decryption buffer might be mutating in place." 
+ ), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_secret_warning_lock_contention_and_bound() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let tasks = 50; + let iterations_per_task = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for t in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + for i in 0..iterations_per_task { + let user_name = format!("contention_user_{}_{}", t, i); + warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None); + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let warned = INVALID_SECRET_WARNED.get().unwrap(); + let guard = warned + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + + assert_eq!( + guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mtproto_strict_concurrent_replay_race_condition() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; + let config = Arc::new(test_config_with_secret_hex(secret_hex)); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let valid_handshake = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 1, + )); + + let tasks = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for i in 0..tasks { + let b = barrier.clone(); + let cfg = config.clone(); + let rc = replay_checker.clone(); + let hs = valid_handshake.clone(); + + handles.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), + 10000 + i as 
u16, + ); + b.wait().await; + handle_mtproto_handshake( + &hs, + tokio::io::empty(), + tokio::io::sink(), + peer, + &cfg, + &rc, + false, + None, + ) + .await + })); + } + + let mut successes = 0; + let mut failures = 0; + + for handle in handles { + match handle.await.unwrap() { + HandshakeResult::Success(_) => successes += 1, + HandshakeResult::BadClient { .. } => failures += 1, + _ => panic!("Unexpected error result in concurrent MTProto replay test"), + } + } + + assert_eq!( + successes, 1, + "Replay cache race condition allowed multiple identical MTProto handshakes to succeed" + ); + assert_eq!( + failures, + tasks - 1, + "Replay cache failed to forcefully reject concurrent duplicates" + ); +} + +#[tokio::test] +async fn tls_alpn_zero_length_protocol_handled_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x5Bu8; 16]; + let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); + + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::BadClient { .. 
}), + "0-length ALPN must be safely rejected without panicking" + ); +} + +#[tokio::test] +async fn tls_sni_massive_hostname_does_not_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x6Cu8; 16]; + let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); + + let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!( + res, + HandshakeResult::Success(_) | HandshakeResult::BadClient { .. } + ), + "Massive SNI hostname must be processed or ignored without stack overflow or panic" + ); +} + +#[tokio::test] +async fn tls_progressive_truncation_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x7Du8; 16]; + let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); + + let valid_handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let full_len = valid_handshake.len(); + + // Truncated corpus only: full_len is a valid baseline and should not be + // asserted as BadClient in a truncation-specific test. 
+ for i in (0..full_len).rev() { + let truncated = &valid_handshake[..i]; + let res = handle_tls_handshake( + truncated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Truncated TLS handshake at len {} must fail safely without panicking", + i + ); + } +} + +#[tokio::test] +async fn mtproto_pure_entropy_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap(); + + let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE); + + for _ in 0..10_000 { + let mut noise = [0u8; HANDSHAKE_LEN]; + seeded.fill_bytes(&mut noise); + + let res = handle_mtproto_handshake( + &noise, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!( + matches!(res, HandshakeResult::BadClient { .. 
}), + "Pure entropy MTProto payload must fail closed and never panic" + ); + } +} + +#[test] +fn decode_user_secret_odd_length_hex_rejection() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert( + "odd_user".to_string(), + "1234567890123456789012345678901".to_string(), + ); + + let decoded = decode_user_secrets(&config, None); + assert!( + decoded.is_empty(), + "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping" + ); +} + +#[test] +fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112)); + let now = Instant::now(); + + let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5; + state.insert( + peer_ip, + AuthProbeState { + fail_streak: extreme_streak, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); + assert!( + is_throttled, + "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period" + ); +} + +#[test] +fn auth_probe_saturation_note_resets_retention_window() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let base_time = Instant::now(); + + auth_probe_note_saturation(base_time); + let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1); + 
auth_probe_note_saturation(later); + + let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5); + + // This call may return false if backoff has elapsed, but it must not clear + // the saturation state because `later` refreshed last_seen. + let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time); + let guard = auth_probe_saturation_state_lock(); + assert!( + guard.is_some(), + "Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window" + ); +} + +#[test] +fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false)); + assert!(!mode_enabled_for_proto( + &config, + ProtoTag::Intermediate, + false + )); +} + +#[test] +fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = true; + config.general.modes.secure = false; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false)); +} + +#[test] +fn ipv6_localhost_and_unspecified_normalization() { + let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + let norm_local = normalize_auth_probe_ip(localhost); + let norm_unspec = normalize_auth_probe_ip(unspecified); + + let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + assert_eq!(norm_local, expected_bucket); + assert_eq!(norm_unspec, expected_bucket); +} diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs new file mode 100644 index 0000000..1e27ed5 --- /dev/null +++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs @@ -0,0 +1,340 @@ +use 
super::*; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec<u8> { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + 
handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + 
+#[tokio::test] +async fn tls_alpn_reject_does_not_pollute_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let mut config = test_config_with_secret_hex("11111111111111111111111111111111"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let before = replay_checker.stats(); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + let after = replay_checker.stats(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert_eq!( + before.total_additions, after.total_additions, + "ALPN policy reject must not add TLS digest into replay cache" + ); +} + +#[tokio::test] +async fn tls_truncated_session_id_len_fails_closed_without_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("33333333333333333333333333333333"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut malicious = vec![0x42u8; min_len]; + malicious[min_len - 1] = u8::MAX; + + let res = handle_tls_handshake( + &malicious, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); +} + +#[test] +fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let same = Instant::now(); + + for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8)); + state.insert( + ip, + AuthProbeState { + fail_streak: 7, + blocked_until: same, + last_seen: same, + }, + ); + } + + let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21)); + auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1)); + + assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES); + assert!(state.contains_key(&new_ip)); +} + +#[test] +fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _hold = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + panic!("intentional poison for regression coverage"); + }); + let _ = poison_thread.join(); + + clear_auth_probe_state_for_testing(); + + let guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none()); +} + +#[tokio::test] +async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() { + let _probe_guard = auth_probe_test_guard(); + let _warn_guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "short_user".to_string(), + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(), + ); + + let valid_secret_hex = "77777777777777777777777777777777"; + config 
+ .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap(); + let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_))); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80)); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let state = auth_probe_state_map(); + state.insert( + peer_ip, + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now, + last_seen: now, + }, + ); + + let tasks = 32; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for _ in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let final_state = state.get(&peer_ip).expect("state must exist"); + assert!( + final_state.fail_streak + >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + ); + assert!(auth_probe_should_apply_preauth_throttle( + peer_ip, + Instant::now() + )); +} diff --git 
a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index d06f63e..6796c5c 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { } } +#[tokio::test] +async fn tls_unknown_sni_drop_policy_returns_hard_error() { + let secret = [0x48u8; 16]; + let mut config = test_config_with_secret_hex("48484848484848484848484848484848"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!( + result, + HandshakeResult::Error(ProxyError::UnknownTlsSni) + )); +} + +#[tokio::test] +async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() { + let secret = [0x49u8; 16]; + let mut config = test_config_with_secret_hex("49494949494949494949494949494949"); + config.censorship.unknown_sni_action = UnknownSniAction::Mask; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn tls_missing_sni_keeps_legacy_auth_path() { + let secret = [0x4Au8; 16]; + let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs new file mode 100644 index 0000000..13d112c --- /dev/null +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -0,0 +1,318 @@ +use super::*; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, + salt: u8, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1).wrapping_add(salt); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + 
PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + let host_bytes = 
sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn median_ns(samples: &mut [u128]) -> u128 { + samples.sort_unstable(); + samples[samples.len() / 2] +} + +#[tokio::test] +#[ignore = "manual benchmark: 
timing-sensitive and host-dependent"] +async fn mtproto_user_scan_timing_manual_benchmark() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "target_user"; + let target_secret_hex = "dededededededededededededededede"; + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); + let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); + let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap(); + let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap(); + + let mut preferred_samples = Vec::with_capacity(ITERATIONS); + let mut full_scan_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let handshake = make_valid_mtproto_handshake( + target_secret_hex, + ProtoTag::Secure, + 1 + i as i16, + (i % 251) as u8, + ); + + let started_preferred = Instant::now(); + let preferred = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_a, + &config, + &replay_checker_preferred, + false, + Some(preferred_user), + ) + .await; + preferred_samples.push(started_preferred.elapsed().as_nanos()); + assert!(matches!(preferred, HandshakeResult::Success(_))); + + let started_scan = Instant::now(); + let full_scan = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_b, + &config, + &replay_checker_full_scan, + false, + None, + ) + .await; + full_scan_samples.push(started_scan.elapsed().as_nanos()); + assert!(matches!(full_scan, 
HandshakeResult::Success(_))); + } + + let preferred_median = median_ns(&mut preferred_samples); + let full_scan_median = median_ns(&mut full_scan_samples); + + let ratio = if preferred_median == 0 { + 0.0 + } else { + full_scan_median as f64 / preferred_median as f64 + }; + + println!( + "manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}" + ); + + assert!( + full_scan_median >= preferred_median, + "full user scan should not be faster than preferred-user path in this benchmark" + ); +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() { + let _guard = auth_probe_test_guard(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "user-b"; + let target_secret_hex = "abababababababababababababababab"; + let target_secret = [0xABu8; 16]; + + let mut config = ProxyConfig::default(); + config.general.modes.tls = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let mut sni_samples = Vec::with_capacity(ITERATIONS); + let mut no_sni_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let with_sni = make_valid_tls_client_hello_with_sni_and_alpn( + &target_secret, + i as u32, + preferred_user, + &[b"h2"], + ); + let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000)); + + let started_sni = Instant::now(); + let sni_secrets = decode_user_secrets(&config, Some(preferred_user)); + let sni_result = tls::validate_tls_handshake_with_replay_window( + &with_sni, + &sni_secrets, + config.access.ignore_time_skew, + 
config.access.replay_window_secs, + ); + sni_samples.push(started_sni.elapsed().as_nanos()); + assert!(sni_result.is_some()); + + let started_no_sni = Instant::now(); + let no_sni_secrets = decode_user_secrets(&config, None); + let no_sni_result = tls::validate_tls_handshake_with_replay_window( + &no_sni, + &no_sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + no_sni_samples.push(started_no_sni.elapsed().as_nanos()); + assert!(no_sni_result.is_some()); + } + + let sni_median = median_ns(&mut sni_samples); + let no_sni_median = median_ns(&mut no_sni_samples); + + let ratio = if sni_median == 0 { + 0.0 + } else { + no_sni_median as f64 / sni_median as f64 + }; + + println!( + "manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 3e860e8..a977409 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u ]; let mut meaningful_improvement_seen = false; - let mut baseline_sum = 0.0f64; - let mut hardened_sum = 0.0f64; - let mut pair_count = 0usize; + let mut informative_baseline_sum = 0.0f64; + let mut informative_hardened_sum = 0.0f64; + let mut informative_pair_count = 0usize; + let mut low_info_baseline_sum = 0.0f64; + let mut low_info_hardened_sum = 0.0f64; + let mut low_info_pair_count = 0usize; let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64; let tolerated_pair_regression = acc_quant_step + 0.03; @@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u hardened_acc <= baseline_acc + 
tolerated_pair_regression, "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" ); + informative_baseline_sum += baseline_acc; + informative_hardened_sum += hardened_acc; + informative_pair_count += 1; + } else { + // Low-information pairs (near-random baseline separability) are expected + // to exhibit quantized jitter at low sample counts; do not fold them into + // strict average-regression checks used for informative side-channel signal. + low_info_baseline_sum += baseline_acc; + low_info_hardened_sum += hardened_acc; + low_info_pair_count += 1; } println!( @@ -531,20 +544,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u if hardened_acc + 0.05 <= baseline_acc { meaningful_improvement_seen = true; } - - baseline_sum += baseline_acc; - hardened_sum += hardened_acc; - pair_count += 1; } - let baseline_avg = baseline_sum / pair_count as f64; - let hardened_avg = hardened_sum / pair_count as f64; + assert!( + informative_pair_count > 0, + "expected at least one informative pair for timing-separability guard" + ); + + let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64; + let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64; assert!( - hardened_avg <= baseline_avg + 0.10, - "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + informative_hardened_avg <= informative_baseline_avg + 0.10, + "normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}" ); + if low_info_pair_count > 0 { + let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64; + let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64; + assert!( + 
low_info_hardened_avg <= low_info_baseline_avg + 0.40, + "normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}" + ); + } + // Optional signal only: do not require improvement on every run because // noisy CI schedulers can flatten pairwise differences at low sample counts. let _ = meaningful_improvement_seen; diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs new file mode 100644 index 0000000..a6f6386 --- /dev/null +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncRead; +use tokio::time::{Duration, timeout}; + +struct EndlessReader { + produced: Arc, +} + +impl AsyncRead for EndlessReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let len = buf.remaining().max(1); + let fill = vec![0xAA; len]; + buf.put_slice(&fill); + self.produced.fetch_add(len, Ordering::Relaxed); + Poll::Ready(Ok(())) + } +} + +#[test] +fn loop_guard_unspecified_bind_uses_interface_inventory() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap(); + let interfaces = vec!["192.168.44.10".parse().unwrap()]; + + assert!(is_mask_target_local_listener_with_interfaces( + "mask.example", + 443, + local, + Some(resolved), + &interfaces, + )); +} + +#[tokio::test] +async fn consume_client_data_stops_after_byte_cap_without_eof() { + let produced = Arc::new(AtomicUsize::new(0)); + let reader = EndlessReader { + produced: Arc::clone(&produced), + }; + let cap = 10_000usize; + + consume_client_data(reader, cap).await; + + let total = produced.load(Ordering::Relaxed); + assert!( + total >= cap, 
+ "consume path must read at least up to cap before stopping" + ); + assert!( + total <= cap + 8192, + "consume path must stop within one read chunk above cap" + ); +} + +#[test] +fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl = masking_beobachten_ttl(&config); + assert_eq!(ttl, std::time::Duration::from_secs(60)); +} + +#[test] +fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 0; + config.censorship.mask_timing_normalization_ceiling_ms = 0; + + let budget = mask_outcome_target_budget(&config); + assert_eq!( + budget, + Duration::from_millis(0), + "zero floor/ceiling must produce zero extra normalization budget" + ); +} + +#[tokio::test] +async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap(); + let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + 
.await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "loop guard must fail closed before any recursive PROXY protocol amplification" + ); +} diff --git a/src/proxy/tests/masking_aggressive_mode_security_tests.rs b/src/proxy/tests/masking_aggressive_mode_security_tests.rs index a77fc14..7356dc0 100644 --- a/src/proxy/tests/masking_aggressive_mode_security_tests.rs +++ b/src/proxy/tests/masking_aggressive_mode_security_tests.rs @@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() { let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await; let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await; - assert!(legacy < floor, "legacy mode should keep timeout path unshaped"); + assert!( + legacy < floor, + "legacy mode should keep timeout path unshaped" + ); assert!( aggressive >= floor, "aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})" diff --git a/src/proxy/tests/masking_classification_completeness_security_tests.rs b/src/proxy/tests/masking_classification_completeness_security_tests.rs new file mode 100644 index 0000000..35bf87b --- /dev/null +++ b/src/proxy/tests/masking_classification_completeness_security_tests.rs @@ -0,0 +1,16 @@ +use super::*; + +#[test] +fn detect_client_type_recognizes_extended_http_probe_verbs() { + assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP"); +} + +#[test] +fn detect_client_type_recognizes_fragmented_http_method_prefixes() { + assert_eq!(detect_client_type(b"CO"), "HTTP"); + assert_eq!(detect_client_type(b"CON"), "HTTP"); + assert_eq!(detect_client_type(b"TR"), "HTTP"); + assert_eq!(detect_client_type(b"PAT"), "HTTP"); +} diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs 
b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs new file mode 100644 index 0000000..718189c --- /dev/null +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; +use crate::network::dns_overrides::install_entries; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +async fn run_connect_failure_case( + host: &str, + port: u16, + timing_normalization_enabled: bool, + peer: SocketAddr, +) -> Duration { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(host.to_string()); + config.censorship.mask_port = port; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n"; + + let (mut client_writer, client_reader) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!( + n, 0, + "connect-failure path must close client-visible writer" + ); + + started.elapsed() +} + +#[tokio::test] +async fn connect_failure_refusal_close_behavior_matrix() { + let 
temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) + .parse() + .unwrap(); + let elapsed = + run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized refusal path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized refusal path must honor baseline connect budget without stalling" + ); + } + } +} + +#[tokio::test] +async fn connect_failure_overridden_hostname_close_behavior_matrix() { + let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + // Make hostname resolution deterministic in tests so timing ceilings are meaningful. 
+ install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap(); + + for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() { + let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16) + .parse() + .unwrap(); + let elapsed = run_connect_failure_case( + "mask.invalid", + unused_port, + timing_normalization_enabled, + peer, + ) + .await; + + if timing_normalization_enabled { + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250), + "normalized overridden-host path must honor configured timing envelope without stalling" + ); + } else { + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150), + "non-normalized overridden-host path must honor baseline connect budget without stalling" + ); + } + } + + install_entries(&[]).unwrap(); +} diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs new file mode 100644 index 0000000..f2c39a2 --- /dev/null +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -0,0 +1,88 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0x42]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn stalling_client_terminates_at_idle_not_relay_timeout() { + let reader = OneByteThenStall { sent: false }; + let started = Instant::now(); + + let result = tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(reader, MASK_BUFFER_SIZE * 4), + ) + .await; + + assert!( + result.is_ok(), + "consume_client_data should complete by per-read idle timeout, not hit 
relay timeout" + ); + + let elapsed = started.elapsed(); + assert!( + elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2), + "consume_client_data returned too quickly for idle-timeout path: {elapsed:?}" + ); + assert!( + elapsed < MASK_RELAY_TIMEOUT, + "consume_client_data waited full relay timeout ({elapsed:?}); \ + per-read idle timeout is missing" + ); +} + +#[tokio::test] +async fn fast_reader_drains_to_eof() { + let data = vec![0xAAu8; 32 * 1024]; + let reader = std::io::Cursor::new(data); + + tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX)) + .await + .expect("consume_client_data did not complete for fast EOF reader"); +} + +#[tokio::test] +async fn io_error_terminates_cleanly() { + struct ErrReader; + + impl AsyncRead for ErrReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "simulated reset", + ))) + } + } + + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(ErrReader, usize::MAX), + ) + .await + .expect("consume_client_data did not return on I/O error"); +} diff --git a/src/proxy/tests/masking_consume_stress_adversarial_tests.rs b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs new file mode 100644 index 0000000..12287b5 --- /dev/null +++ b/src/proxy/tests/masking_consume_stress_adversarial_tests.rs @@ -0,0 +1,64 @@ +use super::*; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Instant; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::task::JoinSet; + +struct OneByteThenStall { + sent: bool, +} + +impl AsyncRead for OneByteThenStall { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.sent { + self.sent = true; + buf.put_slice(&[0xAA]); + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } +} + +#[tokio::test] +async fn consume_stall_stress_finishes_within_idle_budget() { 
+ let mut set = JoinSet::new(); + let started = Instant::now(); + + for _ in 0..64 { + set.spawn(async { + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(OneByteThenStall { sent: false }, usize::MAX), + ) + .await + .expect("consume_client_data exceeded relay timeout under stall load"); + }); + } + + while let Some(res) = set.join_next().await { + res.unwrap(); + } + + // Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling + // for 100ms should complete well under a strict 600ms boundary. + assert!( + started.elapsed() < MASK_RELAY_TIMEOUT * 3, + "stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking" + ); +} + +#[tokio::test] +async fn consume_zero_cap_returns_immediately() { + let started = Instant::now(); + consume_client_data(tokio::io::empty(), 0).await; + assert!( + started.elapsed() < MASK_RELAY_IDLE_TIMEOUT, + "zero byte cap must return immediately" + ); +} diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..650731c --- /dev/null +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -0,0 +1,225 @@ +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +fn make_self_target_config( + timing_normalization_enabled: bool, + floor_ms: u64, + ceiling_ms: u64, + beobachten_enabled: bool, +) -> ProxyConfig { + let mut config = ProxyConfig::default(); + config.general.beobachten = beobachten_enabled; + config.general.beobachten_minutes = 5; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + 
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + config +} + +async fn run_self_target_refusal( + config: ProxyConfig, + peer: SocketAddr, + initial: &'static [u8], +) -> Duration { + let beobachten = BeobachtenStore::new(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(3), task) + .await + .expect("self-target refusal must complete in bounded time") + .expect("self-target refusal task must not panic"); + + started.elapsed() +} + +#[tokio::test] +async fn positive_self_target_refusal_honors_normalization_floor() { + let config = make_self_target_config(true, 120, 120, false); + let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260), + "normalized self-target refusal must stay within expected envelope" + ); +} + +#[tokio::test] +async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() { + let config = make_self_target_config(false, 240, 240, false); + let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(180), + "non-normalized path must not inherit normalization floor delays" + ); +} + +#[tokio::test] +async fn edge_ceiling_below_floor_uses_floor_fail_closed() { + let config = make_self_target_config(true, 140, 80, false); + let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid 
peer"); + + let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280), + "ceiling max { + max = elapsed; + } + assert!( + elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320), + "parallel probe latency must stay bounded under normalization" + ); + } + + assert!( + max.saturating_sub(min) <= Duration::from_millis(130), + "normalization should limit path variance across adversarial parallel probes" + ); +} + +#[tokio::test] +async fn integration_beobachten_records_probe_classification_on_refusal() { + let config = make_self_target_config(false, 0, 0, true); + let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer"); + let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + beobachten.snapshot_text(Duration::from_secs(60)) + }); + + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + let snapshot = timeout(Duration::from_secs(3), task) + .await + .expect("integration task must complete") + .expect("integration task must not panic"); + + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.71-1")); +} + +#[tokio::test] +async fn light_fuzz_timing_configuration_matrix_is_bounded() { + let mut seed = 0xA17E_55AA_2026_0323u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let enabled = (seed & 1) == 0; + let floor = (seed >> 8) % 180; + let ceiling = (seed >> 24) % 180; + let config = make_self_target_config(enabled, floor, ceiling, false); + let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + 
(case as u16)) + .parse() + .expect("valid peer"); + + let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await; + + assert!( + elapsed < Duration::from_millis(420), + "fuzz case must stay bounded and never hang" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { + let workers = 64usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let config = make_self_target_config(false, 0, 0, false); + let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16) + .parse() + .expect("valid peer"); + run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await + })); + } + + timeout(Duration::from_secs(5), async { + for task in tasks { + let elapsed = task.await.expect("stress task must not panic"); + assert!( + elapsed < Duration::from_millis(260), + "stress refusal must remain bounded without normalization" + ); + } + }) + .await + .expect("high-fanout refusal workload must complete without deadlock"); +} diff --git a/src/proxy/tests/masking_http2_preface_integration_security_tests.rs b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs new file mode 100644 index 0000000..7f1c03f --- /dev/null +++ b/src/proxy/tests/masking_http2_preface_integration_security_tests.rs @@ -0,0 +1,55 @@ +use super::*; +use tokio::net::TcpListener; +use tokio::time::Duration; + +#[tokio::test] +async fn http2_preface_is_forwarded_and_recorded_as_http() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let preface = preface.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; preface.len()]; + stream.read_exact(&mut 
received).await.unwrap(); + assert_eq!(received, preface); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let (client_reader, _client_writer) = tokio::io::duplex(512); + let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + client_reader, + client_visible_writer, + &preface, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[HTTP]")); + assert!(snapshot.contains("198.51.100.130-1")); +} diff --git a/src/proxy/tests/masking_http2_probe_classification_security_tests.rs b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs new file mode 100644 index 0000000..34e04a9 --- /dev/null +++ b/src/proxy/tests/masking_http2_probe_classification_security_tests.rs @@ -0,0 +1,92 @@ +use super::*; + +#[test] +fn full_http2_preface_classified_as_http_probe() { + let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + assert!( + is_http_probe(preface), + "HTTP/2 connection preface must be classified as HTTP probe" + ); +} + +#[test] +fn partial_http2_preface_3_bytes_classified() { + assert!( + is_http_probe(b"PRI"), + "3-byte HTTP/2 preface prefix must be classified" + ); +} + +#[test] +fn partial_http2_preface_2_bytes_classified() { + assert!( + is_http_probe(b"PR"), + "2-byte HTTP/2 preface prefix must be classified" + ); +} 
+ +#[test] +fn existing_http1_methods_unaffected() { + for prefix in [ + b"GET / HTTP/1.1\r\n".as_ref(), + b"POST /api HTTP/1.1\r\n".as_ref(), + b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(), + b"TRACE / HTTP/1.1\r\n".as_ref(), + b"PATCH / HTTP/1.1\r\n".as_ref(), + ] { + assert!(is_http_probe(prefix)); + } +} + +#[test] +fn non_http_data_not_classified() { + for data in [ + b"\x16\x03\x01\x00\xf1".as_ref(), + b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(), + b"\x00\x01\x02\x03".as_ref(), + b"".as_ref(), + b"P".as_ref(), + ] { + assert!(!is_http_probe(data)); + } +} + +#[test] +fn light_fuzz_non_http_prefixes_not_misclassified() { + // Deterministic pseudo-fuzz to exercise classifier edges while avoiding + // known HTTP method and partial windows. + let mut x = 0x1234_5678u32; + for _ in 0..1024 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = 4 + ((x >> 8) as usize % 12); + let mut data = vec![0u8; len]; + for byte in &mut data { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + *byte = (x & 0xFF) as u8; + } + + if [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"DELETE".as_ref(), + b"OPTIONS".as_ref(), + b"CONNECT".as_ref(), + b"TRACE".as_ref(), + b"PATCH".as_ref(), + b"PRI ".as_ref(), + ] + .iter() + .any(|m| data.starts_with(m)) + { + continue; + } + + assert!( + !is_http_probe(&data), + "non-http pseudo-fuzz input misclassified: {:?}", + &data[..data.len().min(8)] + ); + } +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs new file mode 100644 index 0000000..c8f3ec0 --- /dev/null +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -0,0 +1,85 @@ +use super::*; + +#[test] +fn exact_four_byte_http_tokens_are_classified() { + for token in [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"PRI ".as_ref(), + ] { + assert!( + is_http_probe(token), + "exact 4-byte token 
must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn exact_four_byte_non_http_tokens_are_not_classified() { + for token in [ + b"GEX ".as_ref(), + b"POXT".as_ref(), + b"HEA/".as_ref(), + b"PU\0 ".as_ref(), + b"PRI/".as_ref(), + ] { + assert!( + !is_http_probe(token), + "non-HTTP 4-byte token must not be classified: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"PRI "), "HTTP"); +} + +#[test] +fn exact_long_http_tokens_are_classified() { + for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] { + assert!( + is_http_probe(token), + "exact long HTTP token must be classified as HTTP probe: {:?}", + token + ); + } +} + +#[test] +fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() { + assert_eq!(detect_client_type(b"CONNECT"), "HTTP"); + assert_eq!(detect_client_type(b"TRACE"), "HTTP"); + assert_eq!(detect_client_type(b"PATCH"), "HTTP"); +} + +#[test] +fn light_fuzz_four_byte_ascii_noise_not_misclassified() { + // Deterministic pseudo-fuzz over 4-byte printable ASCII inputs. 
+ let mut x = 0xA17C_93E5u32;
+ for _ in 0..2048 {
+ let mut token = [0u8; 4];
+ for byte in &mut token {
+ x = x.wrapping_mul(1664525).wrapping_add(1013904223);
+ *byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
+ }
+
+ if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
+ .iter()
+ .any(|m| token.as_slice() == *m)
+ {
+ continue;
+ }
+
+ assert!(
+ !is_http_probe(&token),
+ "pseudo-fuzz noise misclassified as HTTP probe: {:?}",
+ token
+ );
+ }
+}
diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs
new file mode 100644
index 0000000..ed6d1ab
--- /dev/null
+++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs
@@ -0,0 +1,41 @@
+#![cfg(unix)]
+
+use super::*;
+use std::sync::{Mutex, OnceLock};
+use tokio::sync::Barrier;
+
+fn interface_cache_test_lock() -> &'static Mutex<()> {
+ static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+ LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
+ let _guard = interface_cache_test_lock()
+ .lock()
+ .unwrap_or_else(|poison| poison.into_inner());
+ reset_local_interface_enumerations_for_tests();
+
+ let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
+ let workers = 32usize;
+ let barrier = std::sync::Arc::new(Barrier::new(workers));
+ let mut tasks = Vec::with_capacity(workers);
+
+ for _ in 0..workers {
+ let barrier = std::sync::Arc::clone(&barrier);
+ tasks.push(tokio::spawn(async move {
+ barrier.wait().await;
+ is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
+ }));
+ }
+
+ for task in tasks {
+ let _ = task.await.expect("parallel cache task must not panic");
+ }
+
+ assert_eq!(
+ local_interface_enumerations_for_tests(),
+ 1,
+ "parallel cold misses must coalesce into a single interface enumeration"
+ );
+}
diff --git a/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs
new file mode 100644
index 0000000..d82cf82
--- /dev/null
+++ b/src/proxy/tests/masking_interface_cache_defense_in_depth_security_tests.rs
@@ -0,0 +1,51 @@
+#![cfg(unix)]
+
+use super::*;
+
+#[test]
+fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
+ let previous = vec![
+ "192.168.100.7"
+ .parse::<IpAddr>()
+ .expect("must parse interface ip"),
+ ];
+ let refreshed = Vec::new();
+
+ let next = choose_interface_snapshot(&previous, refreshed);
+
+ assert_eq!(
+ next, previous,
+ "empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
+ );
+}
+
+#[test]
+fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
+ let previous = vec![
+ "192.168.100.7"
+ .parse::<IpAddr>()
+ .expect("must parse interface ip"),
+ ];
+ let refreshed = vec![
+ "10.55.0.3"
+ .parse::<IpAddr>()
+ .expect("must parse refreshed interface ip"),
+ ];
+
+ let next = choose_interface_snapshot(&previous, refreshed.clone());
+
+ assert_eq!(next, refreshed);
+}
+
+#[test]
+fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
+ let previous = Vec::new();
+ let refreshed = Vec::new();
+
+ let next = choose_interface_snapshot(&previous, refreshed);
+
+ assert!(
+ next.is_empty(),
+ "empty refresh with no previous snapshot should remain empty"
+ );
+}
diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs
new file mode 100644
index 0000000..17debb0
--- /dev/null
+++ b/src/proxy/tests/masking_interface_cache_security_tests.rs
@@ -0,0 +1,49 @@
+#![cfg(unix)]
+
+use super::*;
+use std::sync::{Mutex, OnceLock};
+
+fn interface_cache_test_lock() -> &'static Mutex<()> {
+ static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+ LOCK.get_or_init(|| Mutex::new(()))
+}
+
+#[tokio::test]
+async fn 
tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await; + + assert_eq!( + local_interface_enumerations_for_tests(), + 1, + "interface enumeration must be cached across repeated bad-client checks" + ); +} + +#[tokio::test] +async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { + let _guard = interface_cache_test_lock() + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + reset_local_interface_enumerations_for_tests(); + + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; + + assert!( + !is_local, + "different port must not be treated as local listener" + ); + assert_eq!( + local_interface_enumerations_for_tests(), + 0, + "port mismatch should bypass interface enumeration entirely" + ); +} diff --git a/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..efa4529 --- /dev/null +++ b/src/proxy/tests/masking_offline_target_redteam_expected_fail_tests.rs @@ -0,0 +1,178 @@ +use super::*; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +#[ignore = "red-team expected-fail: 
offline mask target keeps bad-client socket alive before consume timeout boundary"] +async fn redteam_offline_target_should_drop_idle_client_early() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(150)).await; + let write_res = client_write.write_all(b"probe-should-be-closed").await; + assert!( + write_res.is_err(), + "offline target path still keeps client writable before consume timeout" + ); + + handler.abort(); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"] +async fn redteam_offline_target_should_not_sleep_to_mask_refusal() { + let (client_read, mut client_write) = duplex(1024); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + 
handle_bad_client( + client_read, + tokio::io::sink(), + b"\x16\x03\x01\x00\x05hello", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(10), + "offline target path still applies coarse masking sleep and is fingerprintable" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"] +async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() { + let mut samples = Vec::new(); + + for i in 0..12u16 { + let (client_read, mut client_write) = duplex(1024); + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = closed_local_port(); + cfg.censorship.mask_timing_normalization_enabled = false; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let handler = tokio::spawn(async move { + handle_bad_client( + client_read, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer_addr, + local_addr, + &cfg, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + let _ = handler.await; + samples.push(started.elapsed()); + } + + let min = samples.iter().copied().min().unwrap_or_default(); + let max = samples.iter().copied().max().unwrap_or_default(); + let spread = max.saturating_sub(min); + + assert!( + spread <= Duration::from_millis(5), + "offline refusal timing spread too wide for strict red-team envelope: {:?}", + spread + ); +} + +#[tokio::test] +#[ignore = "manual red-team: host resolver failure should complete without panic"] +async fn 
redteam_dns_resolution_failure_must_not_panic() {
+ let (client_read, mut client_write) = duplex(1024);
+
+ let mut cfg = ProxyConfig::default();
+ cfg.general.beobachten = false;
+ cfg.censorship.mask = true;
+ cfg.censorship.mask_unix_sock = None;
+ cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
+ cfg.censorship.mask_port = 443;
+
+ let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
+ let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
+ let beobachten = BeobachtenStore::new();
+
+ let handler = tokio::spawn(async move {
+ handle_bad_client(
+ client_read,
+ tokio::io::sink(),
+ b"GET / HTTP/1.1\r\n\r\n",
+ peer_addr,
+ local_addr,
+ &cfg,
+ &beobachten,
+ )
+ .await;
+ });
+
+ client_write.shutdown().await.unwrap();
+ let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
+ assert!(
+ result.is_ok(),
+ "dns failure path stalled or panicked instead of terminating"
+ );
+}
diff --git a/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs
new file mode 100644
index 0000000..b99b4bc
--- /dev/null
+++ b/src/proxy/tests/masking_padding_timeout_adversarial_tests.rs
@@ -0,0 +1,51 @@
+use super::*;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use std::time::Instant;
+use tokio::io::AsyncWrite;
+
+struct NeverWritable;
+
+impl AsyncWrite for NeverWritable {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ _buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ Poll::Pending
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Pending
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+}
+
+#[tokio::test]
+async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
+ let mut writer = NeverWritable;
+ let started = Instant::now();
+
+ maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, 
false).await;
+
+ assert!(
+ started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
+ "shape padding blocked past timeout budget"
+ );
+}
+
+#[tokio::test]
+async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
+ let mut output = Vec::new();
+ {
+ let mut writer = tokio::io::BufWriter::new(&mut output);
+ maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
+ use tokio::io::AsyncWriteExt;
+ writer.flush().await.unwrap();
+ }
+
+ assert!(output.is_empty());
+}
diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs
new file mode 100644
index 0000000..9ff51ba
--- /dev/null
+++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs
@@ -0,0 +1,283 @@
+use super::*;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::task::{Context, Poll};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::time::{Duration, Instant, timeout};
+
+const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
+
+struct FinitePatternReader {
+ remaining: usize,
+ chunk: usize,
+ read_calls: Arc<AtomicUsize>,
+}
+
+impl FinitePatternReader {
+ fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
+ Self {
+ remaining: total,
+ chunk,
+ read_calls,
+ }
+ }
+}
+
+impl AsyncRead for FinitePatternReader {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ self.read_calls.fetch_add(1, Ordering::Relaxed);
+
+ if self.remaining == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let take = self.remaining.min(self.chunk).min(buf.remaining());
+ if take == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let fill = vec![0x5Au8; take];
+ buf.put_slice(&fill);
+ self.remaining -= take;
+ Poll::Ready(Ok(()))
+ }
+}
+
+#[derive(Default)]
+struct CountingWriter {
+ written: usize,
+}
+
+impl AsyncWrite for CountingWriter {
+ fn 
poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<std::io::Result<usize>> {
+ self.written = self.written.saturating_add(buf.len());
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+ Poll::Ready(Ok(()))
+ }
+}
+
+struct NeverReadyReader;
+
+impl AsyncRead for NeverReadyReader {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ _buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ Poll::Pending
+ }
+}
+
+struct BudgetProbeReader {
+ remaining: usize,
+ total_read: Arc<AtomicUsize>,
+}
+
+impl BudgetProbeReader {
+ fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
+ Self {
+ remaining: total,
+ total_read,
+ }
+ }
+}
+
+impl AsyncRead for BudgetProbeReader {
+ fn poll_read(
+ mut self: Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<std::io::Result<()>> {
+ if self.remaining == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let take = self.remaining.min(buf.remaining());
+ if take == 0 {
+ return Poll::Ready(Ok(()));
+ }
+
+ let fill = vec![0xA5u8; take];
+ buf.put_slice(&fill);
+ self.remaining -= take;
+ self.total_read.fetch_add(take, Ordering::Relaxed);
+ Poll::Ready(Ok(()))
+ }
+}
+
+#[tokio::test]
+async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
+ let read_calls = Arc::new(AtomicUsize::new(0));
+ let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
+ let mut writer = CountingWriter::default();
+
+ let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
+
+ assert_eq!(
+ outcome.total, PROD_CAP_BYTES,
+ "copy path must stop at explicit production cap"
+ );
+ assert_eq!(writer.written, PROD_CAP_BYTES);
+ assert!(
+ !outcome.ended_by_eof,
+ "byte-cap stop must not be misclassified as EOF"
+ );
+}
+
+#[tokio::test]
+async fn negative_consume_with_zero_cap_performs_no_reads() { 
+ let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls)); + + consume_client_data_with_timeout_and_cap(reader, 0).await; + + assert_eq!( + read_calls.load(Ordering::Relaxed), + 0, + "zero cap must return before reading attacker-controlled bytes" + ); +} + +#[tokio::test] +async fn edge_copy_below_cap_reports_eof_without_overread() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let payload = 73 * 1024; + let mut reader = FinitePatternReader::new(payload, 3072, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await; + + assert_eq!(outcome.total, payload); + assert_eq!(writer.written, payload); + assert!( + outcome.ended_by_eof, + "finite upstream below cap must terminate via EOF path" + ); +} + +#[tokio::test] +async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() { + let started = Instant::now(); + + consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await; + + assert!( + started.elapsed() < Duration::from_millis(350), + "never-ready reader must be bounded by idle/relay timeout protections" + ); +} + +#[tokio::test] +async fn integration_consume_path_honors_production_cap_for_large_payload() { + let read_calls = Arc::new(AtomicUsize::new(0)); + let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls); + + let bounded = timeout( + Duration::from_millis(350), + consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES), + ) + .await; + + assert!( + bounded.is_ok(), + "consume path with production cap must finish within bounded time" + ); +} + +#[tokio::test] +async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() { + let byte_cap = 5usize; + let total_read = Arc::new(AtomicUsize::new(0)); + let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read)); + + 
consume_client_data_with_timeout_and_cap(reader, byte_cap).await; + + assert!( + total_read.load(Ordering::Relaxed) <= byte_cap, + "consume path must not read more than configured byte cap" + ); +} + +#[tokio::test] +async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() { + let mut seed = 0x1234_5678_9ABC_DEF0u64; + + for _case in 0..96u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let cap = ((seed & 0x3ffff) as usize).saturating_add(1); + let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1); + let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1); + + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new(payload, chunk, read_calls); + let mut writer = CountingWriter::default(); + + let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await; + let expected = payload.min(cap); + + assert_eq!( + outcome.total, expected, + "copy total must match min(payload, cap) under fuzzed inputs" + ); + assert_eq!(writer.written, expected); + if payload <= cap { + assert!(outcome.ended_by_eof); + } else { + assert!(!outcome.ended_by_eof); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() { + let workers = 8usize; + let mut tasks = Vec::with_capacity(workers); + + for idx in 0..workers { + tasks.push(tokio::spawn(async move { + let read_calls = Arc::new(AtomicUsize::new(0)); + let mut reader = FinitePatternReader::new( + PROD_CAP_BYTES + (idx + 1) * 4096, + 4096 + (idx * 257), + read_calls, + ); + let mut writer = CountingWriter::default(); + copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await + })); + } + + timeout(Duration::from_secs(3), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert_eq!( + outcome.total, PROD_CAP_BYTES, + "stress copy task must stay within 
production cap" + ); + assert!( + !outcome.ended_by_eof, + "stress task should end due to cap, not EOF" + ); + } + }) + .await + .expect("stress suite must complete in bounded time"); +} diff --git a/src/proxy/tests/masking_relay_guardrails_security_tests.rs b/src/proxy/tests/masking_relay_guardrails_security_tests.rs new file mode 100644 index 0000000..257c0f8 --- /dev/null +++ b/src/proxy/tests/masking_relay_guardrails_security_tests.rs @@ -0,0 +1,105 @@ +use super::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink}; +use tokio::time::{Duration, timeout}; + +#[tokio::test] +async fn relay_to_mask_enforces_masking_session_byte_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01]; + let extra = vec![0xAB; 96 * 1024]; + + let (client_reader, mut client_writer) = duplex(128 * 1024); + let (mask_read, _mask_read_peer) = duplex(1024); + let (mut mask_observer, mask_write) = duplex(256 * 1024); + let initial_for_task = initial.clone(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + mask_read, + mask_write, + &initial_for_task, + false, + 512, + 4096, + false, + 0, + false, + 32 * 1024, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + + let mut observed = Vec::new(); + timeout( + Duration::from_secs(2), + mask_observer.read_to_end(&mut observed), + ) + .await + .unwrap() + .unwrap(); + + // In this deterministic test, relay must stop exactly at the configured cap. 
+ assert_eq!(
+ observed.len(),
+ initial.len() + (32 * 1024),
+ "masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
+ observed.len(),
+ initial.len(),
+ 32 * 1024
+ );
+}
+
+#[tokio::test]
+async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
+ let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
+
+ let (client_reader, mut client_writer) = duplex(8 * 1024);
+ let (mask_read, _mask_read_peer) = duplex(8 * 1024);
+ let (mut mask_observer, mask_write) = duplex(8 * 1024);
+ let initial_for_task = initial.clone();
+
+ let relay = tokio::spawn(async move {
+ relay_to_mask(
+ client_reader,
+ sink(),
+ mask_read,
+ mask_write,
+ &initial_for_task,
+ false,
+ 512,
+ 4096,
+ false,
+ 0,
+ false,
+ 32 * 1024,
+ )
+ .await;
+ });
+
+ client_writer.shutdown().await.unwrap();
+
+ let mut observed = Vec::new();
+ timeout(
+ Duration::from_millis(80),
+ mask_observer.read_to_end(&mut observed),
+ )
+ .await
+ .expect("mask backend write side should be half-closed promptly")
+ .unwrap();
+
+ assert_eq!(&observed[..initial.len()], initial.as_slice());
+
+ timeout(Duration::from_secs(2), relay)
+ .await
+ .unwrap()
+ .unwrap();
+}
diff --git a/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs
new file mode 100644
index 0000000..627c48b
--- /dev/null
+++ b/src/proxy/tests/masking_rng_hoist_perf_regression_tests.rs
@@ -0,0 +1,100 @@
+use super::*;
+use tokio::io::AsyncReadExt;
+use tokio::time::{Duration, timeout};
+
+async fn collect_padding(
+ total_sent: usize,
+ enabled: bool,
+ floor: usize,
+ cap: usize,
+ above_cap_blur: bool,
+ blur_max: usize,
+ aggressive: bool,
+) -> Vec<u8> {
+ let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
+
+ maybe_write_shape_padding(
+ &mut tx,
+ total_sent,
+ enabled,
+ floor,
+ cap,
+ above_cap_blur,
+ blur_max,
+ aggressive,
+ )
+ .await;
+
+ drop(tx);
+
+ let mut output = Vec::new();
+ 
timeout(Duration::from_secs(1), rx.read_to_end(&mut output)) + .await + .expect("reading padded output timed out") + .expect("failed reading padded output"); + output +} + +#[tokio::test] +async fn padding_output_is_not_all_zero() { + let output = collect_padding(1, true, 256, 4096, false, 0, false).await; + + assert!( + output.len() >= 255, + "expected at least 255 padding bytes, got {}", + output.len() + ); + + let nonzero = output.iter().filter(|&&b| b != 0).count(); + // In 255 bytes of uniform randomness, the expected number of zero bytes is ~1. + // A weak nonzero check can miss severe entropy collapse. + assert!( + nonzero >= 240, + "RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}", + nonzero, + output.len(), + ); +} + +#[tokio::test] +async fn padding_reaches_first_bucket_boundary() { + let output = collect_padding(1, true, 64, 4096, false, 0, false).await; + assert_eq!(output.len(), 63); +} + +#[tokio::test] +async fn disabled_padding_produces_no_output() { + let output = collect_padding(0, false, 256, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn at_cap_without_blur_produces_no_output() { + let output = collect_padding(4096, true, 64, 4096, false, 0, false).await; + assert!(output.is_empty()); +} + +#[tokio::test] +async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() { + let output = collect_padding(4096, true, 64, 4096, true, 128, true).await; + assert!(!output.is_empty()); + assert!(output.len() <= 128, "blur exceeded max: {}", output.len()); +} + +#[tokio::test] +async fn stress_padding_runs_are_not_constant_pattern() { + // Stress and sanity-check: repeated runs should not collapse to identical + // first 16 bytes across all samples. 
+ let mut first_chunks = Vec::new(); + for _ in 0..64 { + let out = collect_padding(1, true, 64, 4096, false, 0, false).await; + first_chunks.push(out[..16].to_vec()); + } + + let first = &first_chunks[0]; + let all_same = first_chunks.iter().all(|chunk| chunk == first); + assert!( + !all_same, + "all stress samples had identical prefix, rng output appears degenerate" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs index 4519d85..c698b55 100644 --- a/src/proxy/tests/masking_security_tests.rs +++ b/src/proxy/tests/masking_security_tests.rs @@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall false, 0, false, + 5 * 1024 * 1024, ) .await; }); @@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { false, 0, false, + 5 * 1024 * 1024, ), ) .await; diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs new file mode 100644 index 0000000..7f6cb29 --- /dev/null +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -0,0 +1,330 @@ +use super::*; +use std::net::SocketAddr; +use std::net::TcpListener as StdTcpListener; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant, timeout}; + +fn closed_local_port() -> u16 { + let listener = StdTcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +#[tokio::test] +async fn self_target_detection_matches_literal_ipv4_listener() { + let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_matches_bracketed_ipv6_listener() { + let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); + 
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_keeps_same_ip_different_port_forwardable() { + let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_blocks_loopback_target() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); +} + +#[tokio::test] +async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { + let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); + let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); + assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); +} + +#[tokio::test] +async fn self_target_fallback_refuses_recursive_loopback_connect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_addr = listener.local_addr().unwrap(); + let accept_task = tokio::spawn(async move { + timeout(Duration::from_millis(120), listener.accept()) + .await + .is_ok() + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some(local_addr.ip().to_string()); + config.censorship.mask_port = local_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + b"GET 
/", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let accepted = accept_task.await.unwrap(); + assert!( + !accepted, + "self-target masking must fail closed without connecting to local listener" + ); +} + +#[tokio::test] +async fn same_ip_different_port_still_forwards_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /".to_vec(); + let accept_task = tokio::spawn({ + let expected = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + handle_bad_client( + tokio::io::empty(), + tokio::io::sink(), + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn detect_client_type_http_boundary_get_and_post() { + assert_eq!(detect_client_type(b"GET "), "HTTP"); + assert_eq!(detect_client_type(b"GET /"), "HTTP"); + + assert_eq!(detect_client_type(b"POST"), "HTTP"); + assert_eq!(detect_client_type(b"POST "), "HTTP"); + assert_eq!(detect_client_type(b"POSTX"), "HTTP"); +} + +#[test] +fn detect_client_type_tls_and_length_boundaries() { + assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner"); + assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner"); + + 
assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() { + let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap(); + let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + let header = build_mask_proxy_header(1, peer, local).unwrap(); + assert_eq!(header, b"PROXY UNKNOWN\r\n"); +} + +#[test] +fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() { + let floor = usize::MAX / 2 + 1; + let cap = usize::MAX; + let total = floor + 1; + assert_eq!(next_mask_shape_bucket(total, floor, cap), total); +} + +#[tokio::test] +async fn self_target_reject_path_keeps_timing_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client, server) = duplex(1024); + drop(client); + + let started = Instant::now(); + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250), + "self-target reject path must keep coarse timing budget without stalling" + ); +} + +#[tokio::test] +async fn relay_path_idle_timeout_eviction_remains_effective() { + let (client_read, mut client_write) = duplex(1024); + let (mask_read, mask_write) = duplex(1024); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + client_write.write_all(b"a").await.unwrap(); + tokio::time::sleep(Duration::from_millis(180)).await; + let _ = 
client_write.write_all(b"b").await; + }); + + let started = Instant::now(); + relay_to_mask( + client_read, + tokio::io::sink(), + mask_read, + mask_write, + b"init", + false, + 0, + 0, + false, + 0, + false, + 5 * 1024 * 1024, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180), + "idle-timeout eviction must occur before late trickle write" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_respects_timing_normalization_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client.shutdown().await.unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220), + "offline-refusal path must honor normalization budget without unbounded drift" + ); +} + +#[tokio::test] +async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + 
config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = closed_local_port(); + config.censorship.mask_timing_normalization_enabled = false; + + let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(120)).await; + client + .write_all(b"still-open-before-timeout") + .await + .expect("connection should still be open before consume timeout expires"); + + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350), + "offline-refusal path must not retain idle client indefinitely" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs index 982fd26..4fa8da7 100644 --- a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -43,6 +43,7 @@ async fn run_relay_case( above_cap_blur, above_cap_blur_max_bytes, false, + 5 * 1024 * 1024, ) .await; }); diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs index 3c886ba..9abf3c0 100644 --- a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { false, 0, false, + 5 * 1024 * 1024, ) .await; }); diff 
--git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs new file mode 100644 index 0000000..fda6de7 --- /dev/null +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -0,0 +1,58 @@ +#![cfg(unix)] + +use super::*; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::time::{Duration, Instant, timeout}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 443; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = 120; + config.censorship.mask_timing_normalization_ceiling_ms = 120; + + let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer"); + let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); + let beobachten = BeobachtenStore::new(); + + let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(())); + let held_refresh_guard = refresh_lock.lock().await; + + let (mut client, server) = duplex(1024); + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + server, + tokio::io::sink(), + b"GET / HTTP/1.1\r\n\r\n", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(held_refresh_guard); + client + .shutdown() + .await + .expect("client shutdown must succeed"); + + timeout(Duration::from_secs(2), task) + .await + .expect("task must finish in bounded time") + .expect("task must not panic"); + let elapsed = started.elapsed(); + + assert!( + elapsed >= Duration::from_millis(180) && 
elapsed < Duration::from_millis(350), + "timing normalization floor must start after pre-outcome self-target checks" + ); +} diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..7c176bc --- /dev/null +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::crypto::AesCtr; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +struct CountedWriter { + write_calls: Arc, + fail_writes: bool, +} + +impl CountedWriter { + fn new(write_calls: Arc, fail_writes: bool) -> Self { + Self { + write_calls, + fail_writes, + } + } +} + +impl AsyncWrite for CountedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + if this.fail_writes { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced write failure", + ))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { + let stats = Stats::new(); + let user = "middle-me-writer-no-rollback-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true)); + let mut frame_buf = 
Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: payload.clone(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(64), + 0, + &bytes_me2c, + 11, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(_))), + "write failure must propagate as I/O error" + ); + assert!( + write_calls.load(Ordering::Relaxed) > 0, + "writer must be attempted after successful quota reservation" + ); + assert_eq!( + stats.get_user_quota_used(user), + payload.len() as u64, + "reserved quota must not roll back on write failure" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + payload.len() as u64, + "write-fail byte metric must include failed payload size" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 1, + "write-fail events metric must increment once" + ); + assert_eq!( + stats.get_user_total_octets(user), + 0, + "telemetry octets_to should not advance when write fails" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "ME->C committed byte counter must not advance on write failure" + ); +} + +#[tokio::test] +async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { + let stats = Stats::new(); + let user = "middle-me-writer-precheck-user"; + let limit = 8u64; + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), limit); + + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + }, + &mut writer, + 
ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(limit), + 0, + &bytes_me2c, + 12, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { .. })), + "pre-write quota rejection must return typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "writer must not be polled when pre-write quota reservation fails" + ); + assert_eq!( + stats.get_me_d2c_quota_reject_pre_write_total(), + 1, + "pre-write quota reject metric must increment" + ); + assert_eq!( + stats.get_user_quota_used(user), + limit, + "failed pre-write reservation must keep previous quota usage unchanged" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + 0, + "write-fail bytes metric must stay unchanged on pre-write reject" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 0, + "write-fail events metric must stay unchanged on pre-write reject" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs deleted file mode 100644 index 2c9f3f6..0000000 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ /dev/null @@ -1,112 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - 
"middle-blackhat-held-{}-{idx}", - std::process::id() - ))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: bounded lock cache must be saturated" - ); - - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Close) - .await - .expect("queue prefill should succeed"); - - let pressure_seq_before = relay_pressure_event_seq(); - let pressure_errors = Arc::new(AtomicUsize::new(0)); - let mut pressure_workers = Vec::new(); - for _ in 0..16 { - let tx = tx.clone(); - let pressure_errors = Arc::clone(&pressure_errors); - pressure_workers.push(tokio::spawn(async move { - if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { - pressure_errors.fetch_add(1, Ordering::Relaxed); - } - })); - } - - let stats = Arc::new(Stats::new()); - let user = format!("middle-blackhat-quota-race-{}", std::process::id()); - let gate = Arc::new(Barrier::new(16)); - - let mut quota_workers = Vec::new(); - for _ in 0..16u8 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - quota_workers.push(tokio::spawn(async move { - gate.wait().await; - let user_lock = quota_user_lock(&user); - let _quota_guard = user_lock.lock().await; - - if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { - return false; - } - stats.add_user_octets_to(&user, 1); - true - })); - } - - let mut ok_count = 0usize; - let mut denied_count = 0usize; - for worker in quota_workers { - let result = timeout(Duration::from_secs(2), worker) - .await - .expect("quota worker must finish") - .expect("quota worker must not panic"); - if result { - ok_count += 1; - } else { - denied_count += 1; - } - } - - for worker in pressure_workers { - timeout(Duration::from_secs(2), worker) - .await - .expect("pressure worker must finish") - .expect("pressure worker must not panic"); - } - - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "black-hat campaign must not overshoot same-user quota under saturation" - ); - assert!(ok_count <= 1, 
"at most one quota contender may succeed"); - assert!( - denied_count >= 15, - "all remaining contenders must be quota-denied" - ); - - let pressure_seq_after = relay_pressure_event_seq(); - assert!( - pressure_seq_after > pressure_seq_before, - "queue pressure leg must trigger pressure accounting" - ); - assert!( - pressure_errors.load(Ordering::Relaxed) >= 1, - "at least one pressure worker should fail from persistent backpressure" - ); - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs deleted file mode 100644 index fff26b4..0000000 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ /dev/null @@ -1,708 +0,0 @@ -use super::*; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::stats::Stats; -use crate::stream::{BufferPool, PooledBuffer}; -use std::sync::Arc; -use tokio::io::AsyncReadExt; -use tokio::io::duplex; -use tokio::sync::mpsc; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[tokio::test] -async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("abridged quickack payload should serialize"); - 
writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); - assert_eq!(&plaintext[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Boundary where abridged switches to extended length encoding. - let payload = vec![0x5Au8; 0x7f * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended abridged payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized extended abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); - assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); - assert_eq!(&plaintext[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &[1, 2, 3], - &rng, - &mut frame_buf, - ) - .await - 
.expect_err("misaligned abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("4-byte aligned"), - "error should explain alignment contract, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &[9, 8, 7, 6, 5], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned secure payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Secure payload must be 4-byte aligned"), - "error should be explicit for fail-closed triage, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_intermediate_quickack_sets_length_msb() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = b"hello-middle-relay"; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - payload, - &rng, - &mut frame_buf, - ) - .await - .expect("intermediate quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read intermediate frame"); - let plaintext = decryptor.decrypt(&encrypted); - - let mut len_bytes = [0u8; 4]; - len_bytes.copy_from_slice(&plaintext[..4]); - let len_with_flags = u32::from_le_bytes(len_bytes); - assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); - 
assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); - assert_eq!(&plaintext[4..], payload); -} - -#[tokio::test] -async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. - - write_client_payload( - &mut writer, - ProtoTag::Secure, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - // Secure mode adds 1..=3 bytes of randomized tail padding. - let mut encrypted_header = [0u8; 4]; - read_side - .read_exact(&mut encrypted_header) - .await - .expect("must read secure header"); - let decrypted_header = decryptor.decrypt(&encrypted_header); - let header: [u8; 4] = decrypted_header - .try_into() - .expect("decrypted secure header must be 4 bytes"); - let wire_len_raw = u32::from_le_bytes(header); - - assert_ne!( - wire_len_raw & 0x8000_0000, - 0, - "secure quickack bit must be set" - ); - - let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; - assert!(wire_len >= payload.len()); - let padding_len = wire_len - payload.len(); - assert!( - (1..=3).contains(&padding_len), - "secure writer must add bounded random tail padding, got {padding_len}" - ); - - let mut encrypted_body = vec![0u8; wire_len]; - read_side - .read_exact(&mut encrypted_body) - .await - .expect("must read secure body"); - let decrypted_body = decryptor.decrypt(&encrypted_body); - assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); -} - -#[tokio::test] -#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] -async fn 
write_client_payload_abridged_too_large_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Exactly one 4-byte word above the encodable 24-bit abridged length range. - let payload = vec![0x00u8; (1 << 24) * 4]; - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect_err("oversized abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Abridged frame too large"), - "error must clearly indicate oversize fail-close path, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_ack_intermediate_is_little_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); -} - -#[tokio::test] -async fn write_client_ack_abridged_is_big_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 
4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); -} - -#[tokio::test] -async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { - let (mut read_side, write_side) = duplex(1024 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xABu8; 0x7e * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("boundary payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7e); - assert_eq!(&plain[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x42u8; 0x80 * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7f); - assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); - 
assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_intermediate_zero_length_emits_header_only() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0, - &[], - &rng, - &mut frame_buf, - ) - .await - .expect("zero-length intermediate payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); -} - -#[tokio::test] -async fn write_client_payload_intermediate_ignores_unrelated_flags() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [7u8; 12]; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0x4000_0000, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 16]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); - assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - let 
mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x1Du8; 64]; - - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len_raw = u32::from_le_bytes(h); - assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); -} - -#[tokio::test] -async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { - let (mut read_side, write_side) = duplex(256 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x55u8; 100]; - let mut seen = [false; 4]; - - for _ in 0..96 { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; - let padding_len = wire_len - payload.len(); - assert!((1..=3).contains(&padding_len)); - seen[padding_len] = true; - - let mut encrypted_body = vec![0u8; wire_len]; - read_side.read_exact(&mut encrypted_body).await.unwrap(); - let _ = 
decryptor.decrypt(&encrypted_body); - } - - let distinct = (1..=3).filter(|idx| seen[*idx]).count(); - assert!( - distinct >= 2, - "padding generator should not collapse to a single outcome under campaign" - ); -} - -#[tokio::test] -async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { - let (mut read_side, write_side) = duplex(128 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let p1 = vec![1u8; 8]; - let p2 = vec![2u8; 16]; - let p3 = vec![3u8; 20]; - - write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) - .await - .unwrap(); - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - &p2, - &rng, - &mut frame_buf, - ) - .await - .unwrap(); - write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) - .await - .unwrap(); - writer.flush().await.unwrap(); - - // Frame 1: abridged short. - let mut e1 = vec![0u8; 1 + p1.len()]; - read_side.read_exact(&mut e1).await.unwrap(); - let d1 = decryptor.decrypt(&e1); - assert_eq!(d1[0], (p1.len() / 4) as u8); - assert_eq!(&d1[1..], p1.as_slice()); - - // Frame 2: intermediate with quickack. - let mut e2 = vec![0u8; 4 + p2.len()]; - read_side.read_exact(&mut e2).await.unwrap(); - let d2 = decryptor.decrypt(&e2); - let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); - assert_ne!(l2 & 0x8000_0000, 0); - assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); - assert_eq!(&d2[4..], p2.as_slice()); - - // Frame 3: secure with bounded tail. 
- let mut e3h = [0u8; 4]; - read_side.read_exact(&mut e3h).await.unwrap(); - let d3h = decryptor.decrypt(&e3h); - let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize; - assert!(l3 >= p3.len()); - assert!((1..=3).contains(&(l3 - p3.len()))); - let mut e3b = vec![0u8; l3]; - read_side.read_exact(&mut e3b).await.unwrap(); - let d3b = decryptor.decrypt(&e3b); - assert_eq!(&d3b[..p3.len()], p3.as_slice()); -} - -#[test] -fn should_yield_sender_boundary_matrix_blackhat() { - assert!(!should_yield_c2me_sender(0, false)); - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - assert!(should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), - true - )); -} - -#[test] -fn should_yield_sender_light_fuzz_matches_oracle() { - let mut s: u64 = 0xD00D_BAAD_F00D_CAFE; - for _ in 0..5000 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - let sent = (s as usize) & 0x1fff; - let backlog = (s & 1) != 0; - - let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET; - assert_eq!(should_yield_c2me_sender(sent, backlog), expected); - } -} - -#[test] -fn quota_would_be_exceeded_exact_remaining_one_byte() { - let stats = Stats::new(); - let user = "quota-edge"; - let quota = 100u64; - stats.add_user_octets_to(user, 99); - - assert!( - !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1), - "exactly remaining budget should be allowed" - ); - assert!( - quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "one byte beyond remaining budget must be rejected" - ); -} - -#[test] -fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() { - let stats = Stats::new(); - let user = "quota-saturating-edge"; - let quota = u64::MAX - 3; - stats.add_user_octets_to(user, u64::MAX - 4); - - assert!( - 
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "saturating arithmetic edge must stay fail-closed" - ); -} - -#[test] -fn quota_exceeded_boundary_is_inclusive() { - let stats = Stats::new(); - let user = "quota-inclusive-boundary"; - stats.add_user_octets_to(user, 50); - - assert!(quota_exceeded_for_user(&stats, user, Some(50))); - assert!(!quota_exceeded_for_user(&stats, user, Some(51))); -} - -#[tokio::test] -async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { - let (tx, mut rx) = mpsc::channel::(4); - enqueue_c2me_command(&tx, C2MeCommand::Close) - .await - .expect("close should enqueue on fast path"); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .expect("must receive close command") - .expect("close command should be present"); - assert!(matches!(recv, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_data_full_then_drain_preserves_order() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 10, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 20, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = rx.recv().await.expect("first item should exist"); - match first { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1]); - assert_eq!(flags, 10); - } - C2MeCommand::Close => panic!("unexpected close as first item"), - } - - producer.await.unwrap().expect("producer should complete"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("second item should exist"); - match second { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[2, 2]); - assert_eq!(flags, 20); - } - C2MeCommand::Close => panic!("unexpected close as 
second item"), - } -} diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 3e0b30f..fd3243d 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex, OnceLock}; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; @@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl } } -fn idle_pressure_test_lock() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { - match idle_pressure_test_lock().lock() { - Ok(guard) => guard, - Err(poisoned) => poisoned.into_inner(), - } -} - #[tokio::test] async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { let (reader, _writer) = duplex(1024); @@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() { #[test] fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { #[test] fn pressure_does_not_evict_without_new_pressure_signal() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() { 
#[test] fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { #[test] fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { #[test] fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { #[test] fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { #[test] fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { #[test] fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = 
Stats::new(); @@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated( #[test] fn blackhat_stale_pressure_must_not_survive_candidate_churn() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Stats::new(); @@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() { #[test] fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting( #[test] fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); { @@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); @@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { - let _guard = acquire_idle_pressure_test_lock(); + let _guard = relay_idle_pressure_test_scope(); clear_relay_idle_pressure_state_for_testing(); let stats = Arc::new(Stats::new()); diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs 
b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs new file mode 100644 index 0000000..b43825c --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -0,0 +1,62 @@ +use super::*; +use std::panic::{AssertUnwindSafe, catch_unwind}; + +#[test] +fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let mut guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + guard.by_conn_id.insert( + 999, + RelayIdleCandidateMeta { + mark_order_seq: 1, + mark_pressure_seq: 0, + }, + ); + guard.ordered.insert((1, 999)); + panic!("intentional poison for idle-registry recovery"); + })); + + // Helper lock must recover from poison, reset stale state, and continue. + assert!(mark_relay_idle_candidate(42)); + assert_eq!(oldest_relay_idle_candidate(), Some(42)); + + let before = relay_pressure_event_seq(); + note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert!( + after > before, + "pressure accounting must still advance after poison" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() { + let _guard = relay_idle_pressure_test_scope(); + clear_relay_idle_pressure_state_for_testing(); + + let _ = catch_unwind(AssertUnwindSafe(|| { + let registry = relay_idle_candidate_registry(); + let _guard = registry + .lock() + .expect("registry lock must be acquired before poison"); + panic!("intentional poison while lock held"); + })); + + clear_relay_idle_pressure_state_for_testing(); + + assert_eq!(oldest_relay_idle_candidate(), None); + assert_eq!(relay_pressure_event_seq(), 0); + + assert!(mark_relay_idle_candidate(7)); + 
assert_eq!(oldest_relay_idle_candidate(), Some(7)); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs deleted file mode 100644 index d06e103..0000000 --- a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; - -#[test] -fn saturation_uses_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let user = format!("middle-quota-overflow-{}", std::process::id()); - let first = quota_user_lock(&user); - let second = quota_user_lock(&user); - - assert!( - Arc::ptr_eq(&first, &second), - "overflow user must get deterministic same lock while cache is saturated" - ); - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock map" - ); - assert!( - map.get(&user).is_none(), - "overflow user should stay outside bounded lock map under saturation" - ); - - drop(retained); -} - -#[test] -fn overflow_striping_keeps_different_users_distributed() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-dist-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - let a = quota_user_lock("middle-overflow-user-a"); - let b = 
quota_user_lock("middle-overflow-user-b"); - let c = quota_user_lock("middle-overflow-user-c"); - - let distinct = [ - Arc::as_ptr(&a) as usize, - Arc::as_ptr(&b) as usize, - Arc::as_ptr(&c) as usize, - ] - .iter() - .copied() - .collect::>() - .len(); - - assert!( - distinct >= 2, - "striped overflow lock set should avoid collapsing all users to one lock" - ); - - drop(retained); -} - -#[test] -fn reclaim_path_caches_new_user_after_stale_entries_drop() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let user = format!("middle-quota-reclaim-user-{}", std::process::id()); - let got = quota_user_lock(&user); - assert!(map.get(&user).is_some()); - assert!( - Arc::strong_count(&got) >= 2, - "after reclaim, lock should be held both by caller and map" - ); -} - -#[test] -fn overflow_path_same_user_is_stable_across_parallel_threads() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-quota-thread-held-{}-{idx}", - std::process::id() - ))); - } - - let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); - let mut workers = Vec::new(); - for _ in 0..32 { - let user = user.clone(); - workers.push(std::thread::spawn(move || quota_user_lock(&user))); - } - - let first = workers - .remove(0) - .join() - .expect("thread must return lock handle"); - for worker in workers { - let got = worker.join().expect("thread must return lock handle"); - assert!( - Arc::ptr_eq(&first, &got), - "same 
overflow user should resolve to one striped lock even under contention" - ); - } - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs deleted file mode 100644 index 1d3b736..0000000 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ /dev/null @@ -1,2517 +0,0 @@ -use super::*; -use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::network::probe::NetworkDecision; -use crate::proxy::handshake::HandshakeSuccess; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; -use crate::transport::middle_proxy::MePool; -use bytes::Bytes; -use rand::rngs::StdRng; -use rand::{RngExt, SeedableRng}; -use std::collections::{HashMap, HashSet}; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::Mutex; -use std::thread; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::duplex; -use tokio::sync::Barrier; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[test] -fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET - 1, - true - )); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET, - false - 
)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); -} - -#[tokio::test] -async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_closed_channel_recycles_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); - let (tx, rx) = mpsc::channel::(1); - drop(rx); - - let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await; - - assert!(result.is_err(), "closed queue must fail enqueue"); - drop(result); - assert!( - pool.stats().pooled >= 1, - "payload must return to pool when 
enqueue fails on closed channel" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let (tx, rx) = mpsc::channel::(1); - - tx.send(C2MeCommand::Data { - payload: make_pooled_payload_from(&pool, &[9]), - flags: 1, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let pool2 = pool.clone(); - let blocked_send = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), - flags: 2, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), blocked_send) - .await - .expect("blocked send task must finish") - .expect("blocked send task must not panic"); - - assert!( - result.is_err(), - "closing receiver while sender is blocked must fail enqueue" - ); - drop(result); - assert!( - pool.stats().pooled >= 2, - "both queued and blocked payloads must return to pool after channel close" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 0, - }) - .await - .unwrap(); - - let started = Instant::now(); - let result = enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 1, - }, - ) - .await; - - assert!( - result.is_err(), - "enqueue must fail when queue stays full beyond bounded timeout" - ); - assert!( - started.elapsed() < TokioDuration::from_millis(400), - "full-queue timeout must resolve promptly" - ); -} - -#[test] -fn desync_dedup_cache_is_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!( 
- should_emit_full_desync(key, false, now), - "unique keys up to cap must be tracked" - ); - } - - assert!( - should_emit_full_desync(u64::MAX, false, now), - "new key above cap must emit once after bounded eviction for forensic visibility" - ); - - assert!( - !should_emit_full_desync(u64::MAX, false, now), - "already tracked key inside dedup window must stay suppressed" - ); -} - -#[test] -fn quota_user_lock_cache_reuses_entry_for_same_user() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-user-a"); - let b = quota_user_lock("quota-user-a"); - assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); -} - -#[test] -fn quota_user_lock_cache_is_bounded_under_unique_churn() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { - let user = format!("quota-user-{idx}"); - let lock = quota_user_lock(&user); - drop(lock); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay within configured bound" - ); -} - -#[test] -fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - for attempt in 0..8u32 { - map.clear(); - - let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("{prefix}-{idx}"); - retained.push(quota_user_lock(&user)); - } - - if map.len() != QUOTA_USER_LOCKS_MAX { - drop(retained); - continue; - } - - let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - 
map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should use deterministic striping under saturation" - ); - - drop(retained); - return; - } - - panic!("unable to observe stable saturated lock-cache precondition after bounded retries"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-user-{idx}"); - retained.push(quota_user_lock(&user)); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache must be saturated for overflow-user race test" - ); - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-saturated-lock-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); - let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); - let (r1, r2) = tokio::join!(one, two); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "both racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), - "at least one racer must be quota-rejected even when lock cache is saturated" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "saturated lock cache must not permit double-success quota overshoot" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-stress-holder-{idx}"); - retained.push(quota_user_lock(&user)); - } - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-saturated-race-round-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x71, - 12_000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: saturated cache must still enforce exactly one forwarded byte" - ); - } - - drop(retained); -} - -#[test] -fn adversarial_forensics_trace_id_should_not_alias_conn_id() { - let now = Instant::now(); - let trace_id = 0x1122_3344_5566_7788; - let conn_id = 0x8877_6655_4433_2211; - let state = RelayForensicsState { - trace_id, - conn_id, - user: "trace-user".to_string(), - peer: "198.51.100.17:443".parse().unwrap(), - peer_hash: 0x8877_6655_4433_2211, - started_at: now, - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - }; - - assert_ne!( - state.trace_id, state.conn_id, - "security expectation: trace correlation should be independent of connection identity" - ); - assert_eq!(state.trace_id, trace_id); - assert_eq!(state.conn_id, conn_id); -} - -#[tokio::test] -async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { - let (mut writer_side, reader_side) = duplex(8); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) - .await - .expect("ack write must succeed"); - - let mut observed = [0u8; 4]; - writer_side - .read_exact(&mut observed) - .await - .expect("ack bytes must be readable"); - let mut decryptor = AesCtr::new(&key, iv); - let decrypted = decryptor.decrypt(&observed); - - assert_eq!( - decrypted, - 0x11_22_33_44u32.to_be_bytes(), - "abridged ACK should encode confirm bytes in big-endian order" - ); -} - -#[test] -fn desync_dedup_full_cache_churn_stays_suppressed() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!(should_emit_full_desync(key, 
false, now)); - } - - for offset in 0..2048u64 { - let emitted = should_emit_full_desync(u64::MAX - offset, false, now); - if offset == 0 { - assert!( - emitted, - "first full-cache newcomer should emit for forensic visibility" - ); - } else { - assert!( - !emitted, - "full-cache newcomer churn inside emit interval must stay suppressed" - ); - } - } -} - -#[test] -fn dedup_hash_is_stable_for_same_input_within_process() { - let sample = ( - "scope_user", - hash_ip("198.51.100.7".parse().unwrap()), - ProtoTag::Secure, - ); - let first = hash_value(&sample); - let second = hash_value(&sample); - assert_eq!( - first, second, - "dedup hash must be stable within a process for cache lookups" - ); -} - -#[test] -fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { - let mut seen = HashSet::new(); - - for octet in 1u16..=2048 { - let third = ((octet / 256) & 0xff) as u8; - let fourth = (octet & 0xff) as u8; - let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); - let key = hash_value(&( - "scope_user", - hash_ip(ip), - ProtoTag::Secure, - DESYNC_ERROR_CLASS, - )); - seen.insert(key); - } - - assert_eq!( - seen.len(), - 2048, - "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" - ); -} - -#[test] -fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { - let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); - let mut seen = HashSet::new(); - let samples = 8192usize; - - for _ in 0..samples { - let user_seed: u64 = rng.random(); - let peer_seed: u64 = rng.random(); - let proto = if (peer_seed & 1) == 0 { - ProtoTag::Secure - } else { - ProtoTag::Intermediate - }; - let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); - seen.insert(key); - } - - let collisions = samples - seen.len(); - assert!( - collisions <= 1, - "light fuzz collision count should remain negligible for 64-bit dedup keys" - ); -} - -#[test] -fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { - let _guard = 
desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; - - let mut emitted_count = 0usize; - for key in 0..total as u64 { - let emitted = should_emit_full_desync(key, false, now); - if emitted { - emitted_count += 1; - } - } - - assert_eq!( - emitted_count, - DESYNC_DEDUP_MAX_ENTRIES + 1, - "after capacity is reached, same-tick newcomer churn must be rate-limited" - ); - - let len = DESYNC_DEDUP - .get() - .expect("dedup cache must be initialized by stress run") - .len(); - assert!( - len <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must stay bounded under stress churn" - ); -} - -#[test] -fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Same-tick newcomer storm: only the first should emit full forensic record. - let mut burst_emits = 0usize; - for i in 0..1024u64 { - if should_emit_full_desync(10_000_000 + i, false, base_now) { - burst_emits += 1; - } - } - assert_eq!( - burst_emits, 1, - "full-cache newcomer burst must be bounded to a single full emit per interval" - ); - - // After each interval elapses, one newcomer may emit again. 
- for step in 1..=6u64 { - let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; - assert!( - should_emit_full_desync(20_000_000 + step, false, t), - "full-cache newcomer should re-emit once interval has elapsed" - ); - assert!( - !should_emit_full_desync(30_000_000 + step, false, t), - "additional newcomers in the same interval tick must remain suppressed" - ); - } -} - -#[test] -fn full_cache_mode_override_emits_every_event() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for i in 0..10_000u64 { - assert!( - should_emit_full_desync(100_000_000 + i, true, now), - "desync_all_full override must bypass dedup and rate-limit suppression" - ); - } -} - -#[test] -fn report_desync_stats_follow_rate_limited_full_cache_policy() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let stats = Stats::new(); - let mut state = make_forensics_state(); - state.started_at = base_now; - - for i in 0..128u64 { - state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 3, - 1024, - 4096, - Some([0x16, 0x03, 0x03, 0x00]), - &stats, - ); - } - - assert_eq!( - stats.get_desync_total(), - 128, - "every detected desync must increment total counter" - ); - assert_eq!( - stats.get_desync_full_logged(), - 1, - "same-interval full-cache newcomer storm must allow only one full forensic emit" - ); - assert_eq!( - stats.get_desync_suppressed(), - 127, - "remaining same-interval full-cache newcomer events must be suppressed" - ); - - // After one full interval in real wall clock, a newcomer 
should emit again. - thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); - state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 4, - 1024, - 4097, - Some([0x16, 0x03, 0x03, 0x01]), - &stats, - ); - - assert_eq!( - stats.get_desync_full_logged(), - 2, - "full forensic emission must recover after rate-limit interval" - ); -} - -#[test] -fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let emits = Arc::new(AtomicUsize::new(0)); - let mut workers = Vec::new(); - for worker_id in 0..32u64 { - let emits = Arc::clone(&emits); - workers.push(thread::spawn(move || { - for i in 0..512u64 { - let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; - if should_emit_full_desync(key, false, base_now) { - emits.fetch_add(1, Ordering::Relaxed); - } - } - })); - } - - for worker in workers { - worker.join().expect("worker thread must not panic"); - } - - assert_eq!( - emits.load(Ordering::Relaxed), - 1, - "concurrent same-interval full-cache storm must allow only one full forensic emit" - ); -} - -#[test] -fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); - let mut model_last_emit: Option = None; - - 
for i in 0..4096u64 { - let jitter_ms: u64 = rng.random_range(0..=3000); - let t = base_now + TokioDuration::from_millis(jitter_ms); - let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::(); - let actual = should_emit_full_desync(key, false, t); - - let expected = match model_last_emit { - None => { - model_last_emit = Some(t); - true - } - Some(last) => { - match t.checked_duration_since(last) { - Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { - model_last_emit = Some(t); - true - } - Some(_) => false, - None => { - // Match production fail-open behavior for non-monotonic synthetic input. - model_last_emit = Some(t); - true - } - } - } - }; - - assert_eq!( - actual, expected, - "full-cache rate-limit gate diverged from reference model under light fuzz" - ); - } -} - -#[test] -fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Poison the full-cache gate lock intentionally. 
- let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); - let _ = std::panic::catch_unwind(|| { - let _lock = gate - .lock() - .expect("gate lock must be lockable before poison"); - panic!("intentional gate poison for fail-closed regression"); - }); - - let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); - assert!( - !emitted, - "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain bounded even when gate lock is poisoned" - ); -} - -#[test] -fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // First event seeds the gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0001, - false, - base_now + TokioDuration::from_millis(900) - )); - - // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0002, - false, - base_now + TokioDuration::from_millis(100) - )); - - // Same instant again remains suppressed after reset. - assert!(!should_emit_full_desync( - 0xABCD_0000_0000_0003, - false, - base_now + TokioDuration::from_millis(100) - )); -} - -#[test] -fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - // Fill with fresh entries so stale-pruning does not apply. 
- for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - - let newcomer_key = u64::MAX; - let emitted = should_emit_full_desync(newcomer_key, false, base_now); - assert!( - emitted, - "new entry under full fresh cache must emit after bounded eviction" - ); - assert!( - dedup.get(&newcomer_key).is_some(), - "new key must be inserted after bounded eviction" - ); - - let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - let removed_count = before_keys.difference(&after_keys).count(); - let added_count = after_keys.difference(&before_keys).count(); - - assert_eq!( - removed_count, 1, - "full-cache insertion must evict exactly one prior key" - ); - assert_eq!( - added_count, 1, - "full-cache insertion must add exactly one newcomer key" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain hard-bounded after full-cache churn" - ); -} - -#[test] -fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let key = 0xC0DE_CAFE_u64; - let start = Instant::now(); - - assert!( - should_emit_full_desync(key, false, start), - "first event for key must emit full forensic record" - ); - - // Deterministic pseudo-random time deltas around dedup window edge. 
- let mut s: u64 = 0x1234_5678_9ABC_DEF0; - for _ in 0..2048 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - - let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); - let now = start + TokioDuration::from_millis(delta_ms); - let emitted = should_emit_full_desync(key, false, now); - - if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { - assert!( - !emitted, - "events inside dedup window must remain suppressed" - ); - } else { - // Once window elapsed for this key, at least one sample should re-emit and refresh. - if emitted { - return; - } - } - } - - panic!("expected at least one post-window sample to re-emit forensic record"); -} - -fn make_forensics_state() -> RelayForensicsState { - RelayForensicsState { - trace_id: 1, - conn_id: 2, - user: "test-user".to_string(), - peer: "127.0.0.1:50000".parse::().unwrap(), - peer_hash: 3, - started_at: Instant::now(), - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - } -} - -fn make_crypto_reader(reader: R) -> CryptoReader -where - R: tokio::io::AsyncRead + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoReader::new(reader, AesCtr::new(&key, iv)) -} - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { - let general = GeneralConfig::default(); - - MePool::new( - None, - vec![1u8; 32], - None, - false, - None, - Vec::new(), - 1, - None, - 12, - 1200, - HashMap::new(), - HashMap::new(), - None, - NetworkDecision::default(), - None, - Arc::new(SecureRandom::new()), - stats, - general.me_keepalive_enabled, - general.me_keepalive_interval_secs, - general.me_keepalive_jitter_secs, - general.me_keepalive_payload_random, - general.rpc_proxy_req_every, - general.me_warmup_stagger_enabled, - general.me_warmup_step_delay_ms, - general.me_warmup_step_jitter_ms, - 
general.me_reconnect_max_concurrent_per_dc, - general.me_reconnect_backoff_base_ms, - general.me_reconnect_backoff_cap_ms, - general.me_reconnect_fast_retry_count, - general.me_single_endpoint_shadow_writers, - general.me_single_endpoint_outage_mode_enabled, - general.me_single_endpoint_outage_disable_quarantine, - general.me_single_endpoint_outage_backoff_min_ms, - general.me_single_endpoint_outage_backoff_max_ms, - general.me_single_endpoint_shadow_rotate_every_secs, - general.me_floor_mode, - general.me_adaptive_floor_idle_secs, - general.me_adaptive_floor_min_writers_single_endpoint, - general.me_adaptive_floor_min_writers_multi_endpoint, - general.me_adaptive_floor_recover_grace_secs, - general.me_adaptive_floor_writers_per_core_total, - general.me_adaptive_floor_cpu_cores_override, - general.me_adaptive_floor_max_extra_writers_single_per_core, - general.me_adaptive_floor_max_extra_writers_multi_per_core, - general.me_adaptive_floor_max_active_writers_per_core, - general.me_adaptive_floor_max_warm_writers_per_core, - general.me_adaptive_floor_max_active_writers_global, - general.me_adaptive_floor_max_warm_writers_global, - general.hardswap, - general.me_pool_drain_ttl_secs, - general.me_instadrain, - general.me_pool_drain_threshold, - general.me_pool_drain_soft_evict_enabled, - general.me_pool_drain_soft_evict_grace_secs, - general.me_pool_drain_soft_evict_per_writer, - general.me_pool_drain_soft_evict_budget_per_core, - general.me_pool_drain_soft_evict_cooldown_ms, - general.effective_me_pool_force_close_secs(), - general.me_pool_min_fresh_ratio, - general.me_hardswap_warmup_delay_min_ms, - general.me_hardswap_warmup_delay_max_ms, - general.me_hardswap_warmup_extra_passes, - general.me_hardswap_warmup_pass_backoff_base_ms, - general.me_bind_stale_mode, - general.me_bind_stale_ttl_secs, - general.me_secret_atomic_snapshot, - general.me_deterministic_writer_sort, - MeWriterPickMode::default(), - general.me_writer_pick_sample_size, - MeSocksKdfPolicy::default(), 
- general.me_writer_cmd_channel_capacity, - general.me_route_channel_capacity, - general.me_route_backpressure_base_timeout_ms, - general.me_route_backpressure_high_timeout_ms, - general.me_route_backpressure_high_watermark_pct, - general.me_reader_route_data_wait_ms, - general.me_health_interval_ms_unhealthy, - general.me_health_interval_ms_healthy, - general.me_warn_rate_limit_ms, - MeRouteNoWriterMode::default(), - general.me_route_no_writer_wait_ms, - general.me_route_inline_recovery_attempts, - general.me_route_inline_recovery_wait_ms, - ) -} - -fn encrypt_for_reader(plaintext: &[u8]) -> Vec { - let key = [0u8; 32]; - let iv = 0u128; - let mut cipher = AesCtr::new(&key, iv); - cipher.encrypt(plaintext) -} - -#[tokio::test] -async fn read_client_payload_times_out_on_header_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, _writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled header read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_times_out_on_payload_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, mut writer) = duplex(1024); - let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); - writer.write_all(&encrypted_len).await.unwrap(); - - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = 
make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled payload body read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_large_intermediate_frame_is_exact() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(262_144); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!( - frame.len(), - payload_len, - "payload size must match wire length" - ); - for (idx, byte) in frame.iter().enumerate() { - assert_eq!(*byte, (idx as u8).wrapping_mul(31)); - } - assert_eq!(frame_counter, 1, "exactly one frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_strips_tail_padding_bytes() { - let _guard = desync_dedup_test_lock() 
- .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; - let tail = [0xeeu8, 0xff, 0x99]; - let wire_len = payload.len() + tail.len(); - - let mut plaintext = Vec::with_capacity(4 + wire_len); - plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - plaintext.extend_from_slice(&tail); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("secure payload read must succeed") - .expect("secure frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "one secure frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_rejects_wire_len_below_4() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let mut plaintext = Vec::with_capacity(7); - plaintext.extend_from_slice(&3u32.to_le_bytes()); - plaintext.extend_from_slice(&[1u8, 2, 3]); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 
1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), - "secure wire length below 4 must be fail-closed by the frame-too-small guard" - ); -} - -#[tokio::test] -async fn read_client_payload_intermediate_skips_zero_len_frame() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; - let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); - plaintext.extend_from_slice(&0u32.to_le_bytes()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("intermediate payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); -} - -#[tokio::test] -async fn read_client_payload_abridged_extended_len_sets_quickack() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = 
make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = 4 * 130; - let len_words = (payload_len / 4) as u32; - let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); - plaintext.push(0xff | 0x80); - let lw = len_words.to_le_bytes(); - plaintext.extend_from_slice(&lw[..3]); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("abridged payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!( - quickack, - "quickack bit must be propagated from abridged header" - ); - assert_eq!(frame.len(), payload_len); - assert_eq!(frame_counter, 1, "one abridged frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_returns_buffer_to_pool_after_emit() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 8)); - pool.preallocate(1); - assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - // Force growth beyond default pool buffer size to catch ownership-take regressions. 
- let payload_len = 257usize; - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let _ = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 8, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert_eq!(frame_counter, 1); - let pool_stats = pool.stats(); - assert!( - pool_stats.pooled >= 1, - "emitted payload buffer must be returned to pool to avoid pool drain" - ); -} - -#[tokio::test] -async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 2)); - pool.preallocate(1); - assert_eq!( - pool.stats().pooled, - 1, - "one pooled buffer must be available" - ); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; - let mut plaintext = Vec::with_capacity(4 + payload.len()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let (frame, quickack) = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - 
assert!(!quickack); - assert_eq!(frame.as_ref(), &payload); - assert_eq!( - pool.stats().pooled, - 0, - "buffer must stay checked out while frame payload is alive" - ); - - drop(frame); - assert!( - pool.stats().pooled >= 1, - "buffer must return to pool only after frame drop" - ); -} - -#[tokio::test] -async fn enqueue_c2me_close_unblocks_after_queue_drain() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x41]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("first queued item must be present"); - assert!(matches!(first, C2MeCommand::Data { .. })); - - close_task - .await - .unwrap() - .expect("close enqueue must succeed after drain"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("close command must follow after queue drain"); - assert!(matches!(second, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { - let (tx, rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x42]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), close_task) - .await - .expect("close task must finish") - .expect("close task must not panic"); - assert!( - result.is_err(), - "close enqueue must fail cleanly when receiver is dropped under pressure" - ); -} - -#[tokio::test] -async fn process_me_writer_response_ack_obeys_flush_policy() { - let (writer_side, 
_reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let immediate = process_me_writer_response( - MeResponse::Ack(0x11223344), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - true, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - immediate, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: true, - } - )); - - let delayed = process_me_writer_response( - MeResponse::Ack(0x55667788), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - false, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - delayed, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: false, - } - )); -} - -#[tokio::test] -async fn process_me_writer_response_data_updates_byte_accounting() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; - let outcome = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload.clone()), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 88, - false, - false, - ) - .await - .expect("data response must be processed"); - - assert!(matches!( - outcome, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes, - flush_immediately: false, - } if bytes == payload.len() - )); - assert_eq!( - bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), - payload.len() as u64, - "ME->C byte 
accounting must increase by emitted payload size" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_enforces_live_user_quota() { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from("quota-user", 10); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "quota-user", - Some(12), - 0, - &bytes_me2c, - 89, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), - "ME->client runtime path must terminate when live user quota is crossed" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "quota exhaustion must not write any ciphertext to the client stream" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "quota-race-user"; - - let (writer_side_a, _reader_side_a) = duplex(1024); - let (writer_side_b, _reader_side_b) = duplex(1024); - let mut writer_a = make_crypto_writer(writer_side_a); - let mut writer_b = make_crypto_writer(writer_side_b); - let mut frame_buf_a = Vec::new(); - let mut frame_buf_b = Vec::new(); - let rng_a = SecureRandom::new(); - let rng_b = SecureRandom::new(); - - let fut_a = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11]), - }, - &mut writer_a, - ProtoTag::Intermediate, - &rng_a, - &mut frame_buf_a, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 91, - false, - false, 
- ); - let fut_b = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x22]), - }, - &mut writer_b, - ProtoTag::Intermediate, - &rng_b, - &mut frame_buf_b, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 92, - false, - false, - ); - - let (result_a, result_b) = tokio::join!(fut_a, fut_b); - - assert!( - matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_a, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_b, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - stats.get_user_total_octets(user) <= 1, - "same-user concurrent middle-relay responses must not overshoot the configured quota" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() - { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_to("partial-quota-user", 3); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "partial-quota-user", - Some(4), - 0, - &bytes_me2c, - 90, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), - "ME->client runtime path must reject oversized payloads before writing" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "oversized payloads must not leak any partial 
ciphertext to the client stream" - ); -} - -#[tokio::test] -async fn middle_relay_abort_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "abort-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50001".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xdecafbad, - )); - - let started = tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await; - assert!( - started.is_ok(), - "middle relay must increment route gauge before abort" - ); - - relay_task.abort(); - let joined = relay_task.await; - assert!( - joined.is_err(), - "aborted middle relay task must return join error" - ); - - tokio::time::sleep(TokioDuration::from_millis(20)).await; - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay task is aborted mid-flight" - ); - - drop(client_side); -} - -#[tokio::test] -async fn 
middle_relay_cutover_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "cutover-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50003".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xfeed_beef, - )); - - tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("middle relay must increment route gauge before cutover"); - - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance route generation" - ); - - let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) - .await - .expect("middle relay must terminate after cutover") - .expect("middle relay task must not panic"); - assert!( - relay_result.is_err(), - "cutover should terminate middle relay session" - ); - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == 
ROUTE_SWITCH_ERROR_MSG - ), - "client-visible cutover error must stay generic and avoid route-internal metadata" - ); - - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay exits on cutover" - ); - - drop(client_side); -} - -async fn run_quota_race_attempt( - stats: &Stats, - bytes_me2c: &AtomicU64, - user: &str, - payload: u8, - conn_id: u64, - barrier: Arc, -) -> Result { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - barrier.wait().await; - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![payload]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats, - user, - Some(1), - 0, - bytes_me2c, - conn_id, - false, - false, - ) - .await -} - -#[tokio::test] -async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(256); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - 4096, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - result.is_err(), - "oversized abridged length must fail closed" - ); - assert_eq!( - frame_counter, 0, - "oversized frame must not be counted as accepted" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn 
deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); - let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); - - let (r1, r2) = tokio::join!(f1, f2); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "first racer must either finish or fail closed on quota" - ); - assert!( - matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "second racer must either finish or fail closed on quota" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "same-user race must forward/account exactly one payload byte" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_bursts_never_allow_double_success_per_round() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-race-burst-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x33, - 6000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" - ); - } -} - -#[tokio::test] -async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { - let session_count = 6usize; - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let mut relay_tasks = Vec::with_capacity(session_count); - let mut client_sides = Vec::with_capacity(session_count); - - for idx in 0..session_count { - let (server_side, client_side) = duplex(64 * 1024); - client_sides.push(client_side); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: format!("cutover-storm-middle-user-{idx}"), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), - 52000 + idx as u16, - ), - is_tls: false, - }; - - relay_tasks.push(tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_snapshot, - 0xB000_0000 + idx as u64, - ))); - } - - tokio::time::timeout(TokioDuration::from_secs(4), async { - loop { - if stats.get_current_connections_me() == session_count as u64 { - break; - } - 
tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("all middle sessions must become active before cutover storm"); - - let route_runtime_flipper = route_runtime.clone(); - let flipper = tokio::spawn(async move { - for step in 0..64u32 { - let mode = if (step & 1) == 0 { - RelayRouteMode::Direct - } else { - RelayRouteMode::Middle - }; - let _ = route_runtime_flipper.set_mode(mode); - tokio::time::sleep(TokioDuration::from_millis(15)).await; - } - }); - - for relay_task in relay_tasks { - let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) - .await - .expect("middle relay task must finish under cutover storm") - .expect("middle relay task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "storm-cutover termination must remain generic for all middle sessions" - ); - } - - flipper.abort(); - let _ = flipper.await; - - assert_eq!( - stats.get_current_connections_me(), - 0, - "middle route gauge must return to zero after cutover storm" - ); - - drop(client_sides); -} - -#[tokio::test] -async fn secure_padding_distribution_in_relay_writer() { - timeout(TokioDuration::from_secs(10), async { - let (mut client_side, relay_side) = duplex(512 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = Arc::new(SecureRandom::new()); - let mut frame_buf = Vec::new(); - let mut decryptor = AesCtr::new(&key, iv); - - let mut padding_counts = [0usize; 4]; - let iterations = 180usize; - let payload = vec![0xAAu8; 100]; // 4-byte aligned - - for _ in 0..iterations { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload write must succeed"); - writer - .flush() - .await - .expect("writer flush must complete so encrypted frame becomes readable"); - - let mut len_buf = [0u8; 4]; 
- client_side - .read_exact(&mut len_buf) - .await - .expect("must read encrypted secure length"); - let decrypted_len_bytes = decryptor.decrypt(&len_buf); - let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes - .try_into() - .expect("decrypted length must be 4 bytes"); - let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; - - assert!( - wire_len >= payload.len(), - "wire length must include at least payload bytes" - ); - let padding_len = wire_len - payload.len(); - assert!(padding_len >= 1 && padding_len <= 3); - padding_counts[padding_len] += 1; - - // Drain and decrypt frame bytes so CTR state stays aligned across writes. - let mut trash = vec![0u8; wire_len]; - client_side - .read_exact(&mut trash) - .await - .expect("must read encrypted secure frame body"); - let _ = decryptor.decrypt(&trash); - } - - for p in 1..=3 { - let count = padding_counts[p]; - assert!( - count > iterations / 8, - "padding length {p} is under-represented ({count}/{iterations})" - ); - } - }) - .await - .expect("secure padding distribution test exceeded runtime budget"); -} - -#[tokio::test] -async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - // Create an ME pool. 
- let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // ConnRegistry ids are monotonic; reserve one id so we can predict the - // next session conn_id and close it deterministically without relying on - // writer-bound views such as active_conn_ids(). - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_runtime.snapshot(), - 0x1234_5678, - )); - - // Wait until session startup is visible, then unregister the predicted - // conn_id to close the per-session ME response channel. 
- timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before channel close simulation"); - - me_pool.registry().unregister(target_conn_id).await; - - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("Session task must terminate after ME drop and client EOF") - .expect("Session task must not panic"); - - assert!( - result.is_ok(), - "Session should complete cleanly after ME drop when client closes, got: {:?}", - result - ); -} - -#[tokio::test] -async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { - let (client_reader_side, _client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // Predict the next conn_id so we can force-drop its ME channel deterministically. 
- let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user-cutover".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let runtime_clone = route_runtime.clone(); - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - runtime_clone.subscribe(), - runtime_clone.snapshot(), - 0xC001_CAFE, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before race trigger"); - - // Race ME channel drop with route cutover and assert generic client-visible outcome. 
- me_pool.registry().unregister(target_conn_id).await; - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance generation" - ); - - let relay_result = timeout(TokioDuration::from_secs(6), session_task) - .await - .expect("session must terminate under ME-drop + cutover race") - .expect("session task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "race outcome must remain generic and not leak ME internals, got: {:?}", - relay_result - ); -} - -#[tokio::test] -async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - for round in 0..32u64 { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: format!("stress-me-drop-eof-{round}"), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - 
"127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_runtime.snapshot(), - 0xD00D_0000 + round, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("session must start before forced drop in burst round"); - - me_pool.registry().unregister(target_conn_id).await; - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("burst round session must terminate quickly") - .expect("burst round session must not panic"); - - assert!( - result.is_ok(), - "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", - result - ); - } -} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs new file mode 100644 index 0000000..6b1d511 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -0,0 +1,372 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::task::JoinSet; +use tokio::time::{Duration as TokioDuration, sleep}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB200_0000 + conn_id, + conn_id, + user: 
format!("tiny-frame-debt-concurrency-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_millis(50), + } +} + +async fn read_once( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_pure_tiny_floods_all_fail_closed() { + let mut set = JoinSet::new(); + + for idx in 0..32u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(1000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = run_relay_test_step_timeout( + "tiny flood task", + read_once( + &mut crypto_reader, + 
ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); + assert_eq!(frame_counter, 0); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel tiny flood worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_benign_tiny_burst_then_real_all_pass() { + let mut set = JoinSet::new(); + + for idx in 0..24u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(2048); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(2000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [idx as u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(20); + for _ in 0..6 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = run_relay_test_step_timeout( + "benign tiny burst read", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await + .expect("benign payload must parse") + .expect("benign payload must return frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("parallel benign worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { + let mut set = JoinSet::new(); + + for idx in 0..12u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(3000 + idx, started); + 
let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(2000); + for n in 0..180u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + for chunk in encrypted.chunks(17) { + writer.write_all(chunk).await.unwrap(); + sleep(TokioDuration::from_millis(1)).await; + } + drop(writer); + }); + + let mut closed = false; + for _ in 0..220 { + let result = run_relay_test_step_timeout( + "alternating jitter read step", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match result { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected error in alternating jitter case: {other}"), + } + } + + writer_task + .await + .expect("writer jitter task must not panic"); + assert!(closed, "alternating attack must close before EOF"); + }); + } + + while let Some(result) = set.join_next().await { + result.expect("alternating jitter worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_mixed_population_attackers_close_benign_survive() { + let mut set = JoinSet::new(); + + for idx in 0..20u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(4000 + idx, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + if idx % 2 == 0 { + let mut plaintext = Vec::with_capacity(1280); + for n in 0..140u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[n, n, n, n]); + } + writer + 
.write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..200 { + match read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected attacker error: {other}"), + } + } + assert!(closed, "attacker session must fail closed"); + } else { + let payload = [1u8, 9, 8, 7]; + let mut plaintext = Vec::new(); + for _ in 0..4 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + + let got = read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("benign session must parse") + .expect("benign session must return a frame"); + assert_eq!(got.0.as_ref(), &payload); + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("mixed-population worker must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn light_fuzz_parallel_patterns_no_hang_or_panic() { + let mut set = JoinSet::new(); + + for case in 0..40u64 { + set.spawn(async move { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(5000 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut seed = 0x9E37_79B9u64 ^ (case << 8); + let mut plaintext = Vec::with_capacity(2048); + for _ in 0..256 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let is_tiny = (seed & 1) == 0; + if is_tiny { + plaintext.push(0x00); + } else { + plaintext.push(0x01); + plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]); + } + } + 
+ writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + for _ in 0..320 { + let step = run_relay_test_step_timeout( + "fuzz case read step", + read_once( + &mut crypto_reader, + ProtoTag::Abridged, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => break, + Ok(None) => break, + Err(other) => panic!("unexpected fuzz case error: {other}"), + } + } + }); + } + + while let Some(result) = set.join_next().await { + result.expect("fuzz worker must not panic"); + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs new file mode 100644 index 0000000..cbbc971 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -0,0 +1,435 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, PooledBuffer}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use tokio::time::{Duration as TokioDuration, sleep}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB300_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, 
+ bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_millis(50), + } +} + +fn append_tiny_frame(plaintext: &mut Vec, proto: ProtoTag) { + match proto { + ProtoTag::Abridged => plaintext.push(0x00), + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&0u32.to_le_bytes()) + } + } +} + +fn append_real_frame(plaintext: &mut Vec, proto: ProtoTag, payload: [u8; 4]) { + match proto { + ProtoTag::Abridged => { + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + } + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&payload); + } + } +} + +async fn write_chunked_with_jitter( + writer: &mut tokio::io::DuplexStream, + bytes: &[u8], + mut seed: u64, +) { + let mut offset = 0usize; + while offset < bytes.len() { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let chunk_len = 1 + ((seed as usize) & 0x1f); + let end = (offset + chunk_len).min(bytes.len()); + writer.write_all(&bytes[offset..end]).await.unwrap(); + + let delay_ms = ((seed >> 16) % 3) as u64; + if delay_ms > 0 { + sleep(TokioDuration::from_millis(delay_ms)).await; + } + offset = end; + } +} + +async fn read_once_with_state( + crypto_reader: &mut CryptoReader, + proto: ProtoTag, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + idle_state: &mut RelayClientIdleState, +) -> Result> { + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + read_client_payload_with_idle_policy( + crypto_reader, + proto, + 1024, + &buffer_pool, + forensics, + 
frame_counter, + &stats, + &idle_policy, + idle_state, + &last_downstream_activity_ms, + forensics.started_at, + ) + .await +} + +fn is_fail_closed_outcome(result: &Result>) -> bool { + matches!(result, Err(ProxyError::Proxy(_))) + || matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut) +} + +#[tokio::test] +async fn intermediate_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6101, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await; + drop(writer); + + let result = run_relay_test_step_timeout( + "intermediate flood read", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + assert!( + is_fail_closed_outcome(&result), + "zero-length flood must fail closed via debt guard or idle timeout" + ); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn secure_chunked_zero_flood_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6102, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await; + drop(writer); + + let result = run_relay_test_step_timeout( + "secure flood 
read", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + assert!( + is_fail_closed_outcome(&result), + "secure zero-length flood must fail closed via debt guard or idle timeout" + ); + assert_eq!(frame_counter, 0); +} + +#[tokio::test] +async fn intermediate_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6103, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + append_real_frame( + &mut plaintext, + ProtoTag::Intermediate, + [n, n ^ 1, n ^ 2, n ^ 3], + ); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = run_relay_test_step_timeout( + "intermediate alternating read step", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected intermediate alternating error: {other}"), + } + } + + writer_task + .await + .expect("intermediate writer task must not panic"); + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn secure_chunked_alternating_attack_closes_before_eof() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6104, started); 
+ let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut plaintext = Vec::with_capacity(8 * 200); + for n in 0..180u8 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]); + } + let encrypted = encrypt_for_reader(&plaintext); + + let writer_task = tokio::spawn(async move { + write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await; + drop(writer); + }); + + let mut closed = false; + for _ in 0..240 { + let step = run_relay_test_step_timeout( + "secure alternating read step", + read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Ok(None) => break, + Err(other) => panic!("unexpected secure alternating error: {other}"), + } + } + + writer_task + .await + .expect("secure writer task must not panic"); + assert!(closed, "secure alternating attack must fail closed"); +} + +#[tokio::test] +async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6105, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [9u8, 8, 7, 6]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); + } + append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Intermediate, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("intermediate safe 
burst should parse") + .expect("intermediate safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn secure_chunked_safe_small_burst_still_returns_real_frame() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6106, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let payload = [3u8, 1, 4, 1]; + let mut plaintext = Vec::new(); + for _ in 0..7 { + append_tiny_frame(&mut plaintext, ProtoTag::Secure); + } + append_real_frame(&mut plaintext, ProtoTag::Secure, payload); + let encrypted = encrypt_for_reader(&plaintext); + write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await; + + let result = read_once_with_state( + &mut crypto_reader, + ProtoTag::Secure, + &forensics, + &mut frame_counter, + &mut idle_state, + ) + .await + .expect("secure safe burst should parse") + .expect("secure safe burst should return a frame"); + + assert_eq!(result.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn light_fuzz_proto_chunking_outcomes_are_bounded() { + let mut seed = 0xDEAD_BEEF_2026_0322u64; + + for case in 0..48u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let proto = if (seed & 1) == 0 { + ProtoTag::Intermediate + } else { + ProtoTag::Secure + }; + + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let started = Instant::now(); + let forensics = make_forensics(6200 + case, started); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(started); + + let mut stream = Vec::new(); + let mut local_seed = seed ^ case; + for _ in 0..220 { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + if (local_seed & 1) == 0 { + 
append_tiny_frame(&mut stream, proto); + } else { + let b = (local_seed >> 8) as u8; + append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]); + } + } + + let encrypted = encrypt_for_reader(&stream); + write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await; + drop(writer); + + for _ in 0..260 { + let step = run_relay_test_step_timeout( + "fuzz proto read step", + read_once_with_state( + &mut crypto_reader, + proto, + &forensics, + &mut frame_counter, + &mut idle_state, + ), + ) + .await; + + match step { + Ok(Some((_payload, _))) => {} + Err(ProxyError::Proxy(_)) => break, + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break, + Ok(None) => break, + Err(other) => panic!("unexpected proto chunking fuzz error: {other}"), + } + } + } +} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs new file mode 100644 index 0000000..fad87d0 --- /dev/null +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -0,0 +1,804 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB100_0000 + conn_id, + conn_id, + user: format!("tiny-frame-debt-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: 
hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_enabled_idle_policy() -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(50), + hard_idle: Duration::from_millis(120), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: Duration::from_millis(50), + } +} + +async fn read_bounded( + crypto_reader: &mut CryptoReader, + proto_tag: ProtoTag, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> { + run_relay_test_step_timeout( + "tiny-frame debt read step", + read_client_payload_with_idle_policy( + crypto_reader, + proto_tag, + 1024, + buffer_pool, + forensics, + frame_counter, + stats, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + ), + ) + .await +} + +fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option, u32, usize) { + let mut debt = 0u32; + let mut reals = 0usize; + for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() { + if is_tiny { + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + if debt >= TINY_FRAME_DEBT_LIMIT { + return (Some(idx + 1), debt, reals); + } + } else { + reals = reals.saturating_add(1); + debt = debt.saturating_sub(1); + } + } + (None, debt, reals) +} + +#[test] +fn tiny_frame_debt_constants_match_security_budget_expectations() { + assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8); + assert_eq!(TINY_FRAME_DEBT_LIMIT, 512); +} + +#[test] +fn relay_client_idle_state_initial_debt_is_zero() { + let state = RelayClientIdleState::new(Instant::now()); + assert_eq!(state.tiny_frame_debt, 0); +} + +#[test] +fn 
on_client_frame_does_not_reset_tiny_frame_debt() { + let now = Instant::now(); + let mut state = RelayClientIdleState::new(now); + state.tiny_frame_debt = 77; + state.on_client_frame(now); + assert_eq!(state.tiny_frame_debt, 77); +} + +#[test] +fn tiny_frame_debt_increment_is_saturating() { + let mut debt = u32::MAX - 1; + debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY); + assert_eq!(debt, u32::MAX); +} + +#[test] +fn tiny_frame_debt_decrement_is_saturating() { + let mut debt = 0u32; + debt = debt.saturating_sub(1); + assert_eq!(debt, 0); +} + +#[test] +fn consecutive_tiny_frames_close_exactly_at_threshold() { + let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize; + let pattern = vec![true; max_tiny_without_close]; + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, Some(max_tiny_without_close)); +} + +#[test] +fn one_less_than_threshold_tiny_frames_do_not_close() { + let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1; + let pattern = vec![true; tiny_count]; + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt < TINY_FRAME_DEBT_LIMIT); +} + +#[test] +fn alternating_one_to_one_closes_with_bounded_real_frame_count() { + let mut pattern = Vec::with_capacity(512); + for _ in 0..256 { + pattern.push(true); + pattern.push(false); + } + let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(closed_at.is_some()); + assert!( + reals <= 80, + "expected bounded real frames before close, got {reals}" + ); +} + +#[test] +fn alternating_one_to_eight_is_stable_for_long_runs() { + let mut pattern = Vec::with_capacity(9 * 5000); + for _ in 0..5000 { + pattern.push(true); + for _ in 0..8 { + pattern.push(false); + } + } + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert!(debt <= 
TINY_FRAME_DEBT_PER_TINY); +} + +#[test] +fn alternating_one_to_seven_eventually_closes() { + let mut pattern = Vec::with_capacity(8 * 2000); + for _ in 0..2000 { + pattern.push(true); + for _ in 0..7 { + pattern.push(false); + } + } + let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!( + closed_at.is_some(), + "1:7 tiny-to-real must eventually close" + ); +} + +#[test] +fn two_tiny_one_real_closes_faster_than_one_to_one() { + let mut one_to_one = Vec::with_capacity(512); + for _ in 0..256 { + one_to_one.push(true); + one_to_one.push(false); + } + + let mut two_to_one = Vec::with_capacity(768); + for _ in 0..256 { + two_to_one.push(true); + two_to_one.push(true); + two_to_one.push(false); + } + + let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len()); + let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len()); + assert!(a_close.is_some() && b_close.is_some()); + assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0)); +} + +#[test] +fn burst_then_drain_can_recover_without_close() { + let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize; + let mut pattern = Vec::with_capacity(burst_tiny + 600); + for _ in 0..burst_tiny { + pattern.push(true); + } + pattern.extend(std::iter::repeat_n(false, 600)); + + let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert_eq!(closed_at, None); + assert_eq!(debt, 0); +} + +#[test] +fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() { + let mut seed = 0xA5A5_91C3_2026_0322u64; + for _case in 0..128 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = 512 + ((seed as usize) & 0x3ff); + let mut pattern = Vec::with_capacity(len); + let mut local_seed = seed; + for _ in 0..len { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + pattern.push((local_seed & 1) == 0); + } + + let (closed_at, debt, _) = 
simulate_tiny_debt_pattern(&pattern, pattern.len()); + if closed_at.is_none() { + assert!(debt < TINY_FRAME_DEBT_LIMIT); + } + assert!(debt <= u32::MAX); + } +} + +#[test] +fn stress_many_independent_simulations_keep_isolated_debt_state() { + for idx in 0..2048usize { + let mut pattern = Vec::with_capacity(64); + for j in 0..64usize { + pattern.push(((idx ^ j) & 3) == 0); + } + let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY)); + } +} + +#[tokio::test] +async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(11, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(12, session_started_at); + let mut 
frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 4 * 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Secure, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(_)))); +} + +#[tokio::test] +async fn intermediate_alternating_zero_and_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(13, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(3000); + for idx in 0..160u8 { + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&4u32.to_le_bytes()); + plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]); + } + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..220 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Intermediate, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some(_)) => {} + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + 
} + Ok(None) => break, + Err(other) => panic!("unexpected error while probing alternating close: {other}"), + } + } + + assert!(closed, "intermediate alternating attack must fail closed"); +} + +#[tokio::test] +async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(14, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(64); + for _ in 0..8 { + plaintext.push(0x00); + } + plaintext.push(0x01); + plaintext.extend_from_slice(&[1, 2, 3, 4]); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let first = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match first { + Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]), + Err(e) => panic!("unexpected close after small tiny burst: {e}"), + Ok(None) => panic!("unexpected EOF before real frame"), + } +} + +#[tokio::test] +async fn idle_policy_enabled_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = 
make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0u8; 1024]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "idle policy enabled must fail closed for pure zero-length flood" + ); +} + +#[tokio::test] +async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() { + let (reader, mut writer) = duplex(8192); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(256 * 6); + for idx in 0..=255u8 { + plaintext.push(0x00); + plaintext.push(0x01); + plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]); + } + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("alternating flood bytes must be writable"); + drop(writer); + + let mut saw_proxy_close = false; + for _ in 0..300 { + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + match result { + Ok(Some((_payload, _quickack))) => {} + 
Err(ProxyError::Proxy(_)) => { + saw_proxy_close = true; + break; + } + Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"), + Ok(None) => panic!("unexpected EOF before debt-based closure"), + Err(other) => panic!("unexpected error before close: {other}"), + } + } + + assert!( + saw_proxy_close, + "alternating tiny/real sequence must eventually fail closed" + ); +} + +#[tokio::test] +async fn enabled_idle_policy_valid_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(3, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [7u8, 8, 9, 10]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero frame must be writable"); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("valid frame should decode") + .expect("valid frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1); + assert_eq!(frame_counter, 1); +} + +#[tokio::test] +async fn abridged_quickack_tiny_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = 
make_forensics(21, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let flood_plaintext = vec![0x80u8; 256]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "quickack-marked zero-length flood must fail closed" + ); +} + +#[tokio::test] +async fn abridged_extended_zero_len_flood_is_fail_closed() { + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(22, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut flood_plaintext = Vec::with_capacity(4 * 256); + for _ in 0..256 { + flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]); + } + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer.write_all(&flood_encrypted).await.unwrap(); + drop(writer); + + let result = read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(_))), + "extended zero-length abridged flood must fail closed" + ); +} + 
+#[tokio::test] +async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() { + let mut plaintext = Vec::with_capacity(9 * 300); + for idx in 0..300usize { + plaintext.push(0x00); + for _ in 0..8 { + let b = idx as u8; + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]); + } + } + + // Keep the test single-task and deterministic: make duplex capacity larger than the + // generated ciphertext so write_all cannot block waiting for a concurrent reader. + let duplex_capacity = plaintext.len().saturating_add(1024); + let (reader, mut writer) = duplex(duplex_capacity); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(23, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + drop(writer); + + let mut closed = false; + for _ in 0..3000 { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + closed = true; + break; + } + Err(other) => panic!("unexpected error in 1:8 wire test: {other}"), + } + } + + assert!( + !closed, + "wire-level 1:8 tiny-to-real pattern should not trigger debt close" + ); +} + +#[tokio::test] +async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() { + let mut seed = 0xD1CE_BAAD_2026_0322u64; + + for case_idx in 0..32u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let events = 
300 + ((seed as usize) & 0xff); + let mut pattern = Vec::with_capacity(events); + let mut local = seed; + for _ in 0..events { + local ^= local << 7; + local ^= local >> 9; + local ^= local << 8; + pattern.push((local & 0x03) == 0); + } + + let mut plaintext = Vec::with_capacity(events * 6); + for (idx, tiny) in pattern.iter().copied().enumerate() { + if tiny { + plaintext.push(0x00); + } else { + let b = (idx as u8) ^ (case_idx as u8); + plaintext.push(0x01); + plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]); + } + } + + let (reader, mut writer) = duplex(16 * 1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(500 + case_idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_enabled_idle_policy(); + let last_downstream_activity_ms = AtomicU64::new(0); + + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); + drop(writer); + + let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); + let mut observed_close = false; + + for _ in 0..(events + 8) { + match read_bounded( + &mut crypto_reader, + ProtoTag::Abridged, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + { + Ok(Some(_)) => {} + Ok(None) => break, + Err(ProxyError::Proxy(_)) => { + observed_close = true; + break; + } + Err(other) => panic!("unexpected fuzz error: {other}"), + } + } + + assert_eq!( + observed_close, + expected_close.is_some(), + "wire parser behavior must match debt model for case {case_idx}" + ); + } +} diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs new file mode 
100644 index 0000000..dbf6c4c --- /dev/null +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -0,0 +1,121 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; + +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xB000_0000 + conn_id, + conn_id, + user: format!("zero-len-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +#[tokio::test] +async fn adversarial_legacy_zero_length_flood_is_fail_closed() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + + let flood_plaintext = vec![0u8; 128]; + let flood_encrypted = encrypt_for_reader(&flood_plaintext); + writer + .write_all(&flood_encrypted) + .await + .expect("zero-length flood bytes must be writable"); + drop(writer); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, 
+ ) + .await; + + match result { + Err(ProxyError::Proxy(msg)) => { + assert!( + msg.contains("Excessive zero-length"), + "legacy mode must close flood with explicit zero-length reason, got: {msg}" + ); + } + Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"), + Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"), + Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"), + } +} + +#[tokio::test] +async fn business_abridged_nonzero_frame_still_passes() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + + let payload = [1u8, 2, 3, 4]; + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(0x01); + plaintext.extend_from_slice(&payload); + + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("nonzero abridged frame must be writable"); + + let result = read_client_payload_legacy( + &mut crypto_reader, + ProtoTag::Abridged, + 1024, + Duration::from_millis(30), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("valid abridged frame should decode") + .expect("valid abridged frame should return payload"); + + assert_eq!(result.0.as_ref(), &payload); + assert!(!result.1, "quickack flag must remain false"); + assert_eq!(frame_counter, 1); +} diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs index 14754cd..38e6fc7 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() { async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = 
"quota-mid-user"; - let quota = 5000; + let quota = 5000u64; + let c2s_buf_size = 1024usize; let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() { client_writer, server_reader, server_writer, - 1024, + c2s_buf_size, 1024, user, Arc::clone(&stats), @@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() { other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), } - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let mut overshoot_bytes = 0usize; + let mut buf = [0u8; 256]; + loop { + match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n), + Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"), + Err(_) => break, + } + } + + assert!( + overshoot_bytes <= c2s_buf_size, + "post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}" + ); + assert!( + stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64), + "accounted quota must remain bounded by one in-flight chunk overshoot" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..1bb00a6 --- /dev/null +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::time::Instant; + +struct ScriptedWriter { + scripted_writes: Arc>>, + write_calls: Arc, +} + +impl ScriptedWriter { + fn new(script: &[usize], write_calls: Arc) 
-> Self { + Self { + scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())), + write_calls, + } + } +} + +impl AsyncWrite for ScriptedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + let planned = this + .scripted_writes + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(buf.len()); + Poll::Ready(Ok(planned.min(buf.len()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_stats_io_with_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: &[usize], +) -> ( + StatsIo, + Arc, + Arc, + Arc, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let write_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedWriter::new(script, write_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, write_calls, quota_exceeded) +} + +#[tokio::test] +async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { + let user = "direct-partial-charge-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]); + let payload = vec![0xAB; 64 * 1024]; + + let n1 = io + .write(&payload) + .await + .expect("first partial write must succeed"); + let n2 = io + .write(&payload) + .await + .expect("second partial write must succeed"); + let n3 = 
io.write(&payload).await.expect("tail write must succeed"); + + assert_eq!(n1, 8 * 1024); + assert_eq!(n2, 8 * 1024); + assert_eq!(n3, 48 * 1024); + assert_eq!(write_calls.load(Ordering::Relaxed), 3); + assert_eq!( + stats.get_user_quota_used(user), + (n1 + n2 + n3) as u64, + "quota accounting must follow committed bytes only" + ); + assert_eq!( + stats.get_user_total_octets(user), + (n1 + n2 + n3) as u64, + "telemetry octets should match committed bytes on successful writes" + ); + assert!( + !quota_exceeded.load(Ordering::Acquire), + "quota flag should stay false under large remaining budget" + ); +} + +#[tokio::test] +async fn direct_hybrid_branch_selection_matches_contract() { + let near_limit = 256 * 1024u64; + let near_remaining = 32 * 1024u64; + let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-near-limit-hard-check-user", + near_limit, + near_limit - near_remaining, + &[4 * 1024], + ); + let near_payload = vec![0x11; 4 * 1024]; + let near_written = near_io + .write(&near_payload) + .await + .expect("near-limit write must succeed"); + assert_eq!(near_written, 4 * 1024); + assert_eq!( + near_io.quota_bytes_since_check, 0, + "near-limit branch must go through immediate hard check" + ); + + let (mut far_small_io, _stats, _calls, _flag) = + make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]); + let far_small_payload = vec![0x22; 4 * 1024]; + let far_small_written = far_small_io + .write(&far_small_payload) + .await + .expect("small far-from-limit write must succeed"); + assert_eq!(far_small_written, 4 * 1024); + assert_eq!( + far_small_io.quota_bytes_since_check, + 4 * 1024, + "small far-from-limit write must go through amortized path" + ); + + let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-far-large-hard-check-user", + 1_048_576, + 0, + &[32 * 1024], + ); + let far_large_payload = vec![0x33; 32 * 1024]; + let far_large_written = far_large_io + 
.write(&far_large_payload) + .await + .expect("large write must succeed"); + assert_eq!(far_large_written, 32 * 1024); + assert_eq!( + far_large_io.quota_bytes_since_check, 0, + "large write must force immediate hard check even far from limit" + ); +} + +#[tokio::test] +async fn remaining_before_zero_rejects_without_calling_inner_writer() { + let user = "direct-zero-remaining-user"; + let limit = 8u64; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, limit, limit, &[1]); + + let err = io + .write(&[0x44]) + .await + .expect_err("write must fail when remaining quota is zero"); + + assert!( + is_quota_io_error(&err), + "zero-remaining gate must return typed quota I/O error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "inner poll_write must not be called when remaining quota is zero" + ); + assert!( + quota_exceeded.load(Ordering::Acquire), + "zero-remaining gate must set exceeded flag" + ); + assert_eq!(stats.get_user_quota_used(user), limit); +} + +#[tokio::test] +async fn exceeded_flag_blocks_following_poll_before_inner_write() { + let user = "direct-exceeded-visibility-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1, 0, &[1, 1]); + + let first = io + .write(&[0x55]) + .await + .expect("first byte should consume remaining quota"); + assert_eq!(first, 1); + assert!( + quota_exceeded.load(Ordering::Acquire), + "hard check should store quota_exceeded after boundary hit" + ); + + let second = io + .write(&[0x66]) + .await + .expect_err("next write must be rejected by early exceeded gate"); + assert!( + is_quota_io_error(&second), + "following write must fail with typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 1, + "second write must be cut before touching inner writer" + ); + assert_eq!(stats.get_user_quota_used(user), 1); +} + +#[test] +fn adaptive_interval_clamp_matches_contract() { + assert_eq!(quota_adaptive_interval_bytes(0), 4 
* 1024); + assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024); + assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024); + + assert!(should_immediate_quota_check(32 * 1024, 4 * 1024)); + assert!(should_immediate_quota_check(1_048_576, 32 * 1024)); + assert!(!should_immediate_quota_check(1_048_576, 4 * 1024)); +} diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs index 080240a..9a32b26 100644 --- a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: Duration) total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn integration_full_duplex_exact_budget_then_hard_cutoff() { let stats = Arc::new(Stats::new()); @@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test] async fn negative_preloaded_quota_blocks_both_directions_immediately() { let stats = Arc::new(Stats::new()); let user = "quota-preloaded-cutoff-user"; - stats.add_user_octets_from(user, 5); + preload_user_quota(stats.as_ref(), user, 5); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() { relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}) )); - assert!(stats.get_user_total_octets(user) <= 5); + assert!(stats.get_user_quota_used(user) <= 5); } #[tokio::test] @@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test] @@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo delivered_to_server + delivered_to_client <= quota as usize, "combined forwarded bytes must never exceed configured quota" ); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar "fuzz case {case}: forwarded bytes must not exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes must not exceed quota" ); } @@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must hold under concurrent mixed-direction relay stress" ); assert!( diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs new file mode 100644 index 0000000..8ce1c26 --- /dev/null +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -0,0 +1,399 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +async fn 
read_available( + reader: &mut R, + budget: Duration, +) -> usize { + let start = tokio::time::Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 128]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + +#[tokio::test] +async fn positive_quota_path_forwards_both_directions_within_limit() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-positive-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(16), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); + server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); + client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(relay_result.is_ok()); + assert!(stats.get_user_quota_used(user) <= 16); +} + +#[tokio::test] +async fn negative_preloaded_quota_forbids_any_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-negative-user"; + 
preload_user_quota(stats.as_ref(), user, 8); + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0xAA]).await.unwrap(); + server_peer.write_all(&[0xBB]).await.unwrap(); + + assert_eq!( + read_available(&mut server_peer, Duration::from_millis(120)).await, + 0 + ); + assert_eq!( + read_available(&mut client_peer, Duration::from_millis(120)).await, + 0 + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!(stats.get_user_quota_used(user) <= 8); +} + +#[tokio::test] +async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-edge-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0xFE]), + server_peer.write_all(&[0xEF]), + ); + + let mut buf = [0u8; 1]; + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)) + 
.await + .unwrap() + .unwrap_or(0); + + assert!(delivered_s2c + delivered_c2s <= 1); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-blackhat-user"; + let quota = 24u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut total_forwarded = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + total_forwarded += n; + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(3), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. 
}) + )); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_quota_used(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { + let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE); + + for case in 0..32u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-extended-fuzz-{case}"); + let quota = rng.random_range(1u64..=35u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total_forwarded = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::() { + let _ = client_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), server_peer.read(&mut one)).await + { + total_forwarded += n; + } + } else { + let _ = server_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), client_peer.read(&mut one)).await + { + total_forwarded += n; + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})) + ); + assert!(total_forwarded <= quota as usize); + assert!(stats.get_user_quota_used(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_for_one_user_obey_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-extended-stress-user".to_string(); + let quota = 64u64; + + let mut tasks = Vec::new(); + + for worker in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.clone(); + + tasks.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut total = 0usize; + for step in 0..64u8 { + if relay.is_finished() { + break; + } + if (step as usize + worker as usize) % 2 == 0 { + let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), server_peer.read(&mut one)).await + { + total += n; + } + } else { + let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), client_peer.read(&mut one)).await + { + total += n; + } + } + tokio::time::sleep(Duration::from_millis(1)).await; + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})) + ); + total + })); + } + + let mut delivered = 0usize; + for task in tasks { + delivered += task.await.unwrap(); + } + + assert!(stats.get_user_quota_used(&user) <= quota); + assert!(delivered <= quota as usize); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs deleted file mode 100644 index e29e86e..0000000 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ /dev/null @@ -1,438 +0,0 @@ -use super::*; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::Barrier; -use tokio::time::Instant; - -#[test] -fn quota_lock_same_user_returns_same_arc_instance() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-same-user"); - let b = quota_user_lock("quota-lock-same-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[test] -fn quota_lock_parallel_same_user_reuses_single_lock() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-lock-parallel-same"; - let mut handles = Vec::new(); - - for _ in 0..64 { - handles.push(std::thread::spawn(move || quota_user_lock(user))); - } - - let first = handles - .remove(0) - .join() - .expect("thread must return lock handle"); - - for handle in handles { - let got = handle.join().expect("thread must return lock handle"); - assert!(Arc::ptr_eq(&first, &got)); - } -} - -#[test] -fn quota_lock_unique_users_materialize_distinct_entries() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-distinct-{}", 
std::process::id()); - let users: Vec = (0..(QUOTA_USER_LOCKS_MAX / 2)) - .map(|idx| format!("{base}-{idx}")) - .collect(); - - for user in &users { - let _ = quota_user_lock(user); - } - - for user in &users { - assert!( - map.get(user).is_some(), - "lock cache must contain entry for {user}" - ); - } -} - -#[test] -fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-churn-{}", std::process::id()); - for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { - let _ = quota_user_lock(&format!("{base}-{idx}")); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay bounded under unique-user churn" - ); -} - -#[test] -fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "cache must be saturated for overflow check" - ); - - let overflow_user = format!("quota-overflow-{}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow lock cache" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow user lock must stay outside bounded cache under saturation" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive stable striped overflow lock while saturated" - ); - - drop(retained); -} - -#[test] -fn 
quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - // Saturate with retained strong references first so parallel tests cannot - // reclaim our fixture entries before we validate the reclaim path. - let prefix = format!("quota-reclaim-drop-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); - let overflow = quota_user_lock(&overflow_user); - - assert!( - map.get(&overflow_user).is_some(), - "after reclaiming stale entries, overflow user should become cacheable" - ); - assert!( - Arc::strong_count(&overflow) >= 2, - "cacheable overflow lock should be held by both map and caller" - ); -} - -#[test] -fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-held-{}-{idx}", - std::process::id() - ))); - } - - let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); - let a = quota_user_lock(&overflow_user); - let b = quota_user_lock(&overflow_user); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must not receive distinct locks under saturation because that enables quota race bypass" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - 
map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-race-held-{}-{idx}", - std::process::id() - ))); - } - - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-race-user-{}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let worker = |label: u8, stats: Arc, user: String, gate: Arc| { - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[label]).await - }) - }; - - let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - - let _ = tokio::time::timeout(Duration::from_secs(2), async { - let _ = one.await.expect("task one must not panic"); - let _ = two.await.expect("task two must not panic"); - }) - .await - .expect("quota race workers must complete"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "saturated lock path must never overshoot quota for same user" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-stress-held-{}-{idx}", - std::process::id() - ))); - } - - for round in 0..128u32 { - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); - let gate = 
Arc::new(Barrier::new(2)); - - let one = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x31]).await - }) - }; - - let two = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x32]).await - }) - }; - - let _ = one.await.expect("stress task one must not panic"); - let _ = two.await.expect("stress task two must not panic"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "round {round}: saturated path must not overshoot quota" - ); - } - - drop(retained); -} - -#[test] -fn quota_error_classifier_accepts_internal_quota_sentinel_only() { - let err = quota_io_error(); - assert!(is_quota_io_error(&err)); -} - -#[test] -fn quota_error_classifier_rejects_plain_permission_denied() { - let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); - assert!(!is_quota_io_error(&err)); -} - -#[test] -fn quota_lock_test_scope_recovers_after_guard_poison() { - let poison_result = std::thread::spawn(|| { - let _guard = super::quota_user_lock_test_scope(); - panic!("intentional test-only guard poison"); - }) - .join(); - assert!(poison_result.is_err(), "poison setup thread must panic"); - - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = 
quota_user_lock("quota-lock-poison-recovery-user"); - let b = quota_user_lock("quota-lock-poison-recovery-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[tokio::test] -async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { - let stats = Arc::new(Stats::new()); - let user = "quota-zero-user"; - - let (mut client_peer, relay_client) = duplex(2048); - let (relay_server, mut server_peer) = duplex(2048); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 512, - 512, - user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(b"x") - .await - .expect("client write must succeed"); - - let mut probe = [0u8; 1]; - let forwarded = - tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; - if let Ok(Ok(n)) = forwarded { - assert_eq!(n, 0, "zero quota path must not forward payload bytes"); - } - - let result = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate under zero quota") - .expect("relay task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. 
}))); -} - -#[tokio::test] -async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { - let stats = Arc::new(Stats::new()); - - let (mut client_peer, relay_client) = duplex(8192); - let (relay_server, mut server_peer) = duplex(8192); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "quota-none-burst-user", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - let c2s = vec![0xA5; 2048]; - let s2c = vec![0x5A; 1536]; - - client_peer - .write_all(&c2s) - .await - .expect("client burst write must succeed"); - let mut got_c2s = vec![0u8; c2s.len()]; - server_peer - .read_exact(&mut got_c2s) - .await - .expect("server must receive c2s burst"); - assert_eq!(got_c2s, c2s); - - server_peer - .write_all(&s2c) - .await - .expect("server burst write must succeed"); - let mut got_s2c = vec![0u8; s2c.len()]; - client_peer - .read_exact(&mut got_s2c) - .await - .expect("client must receive s2c burst"); - assert_eq!(got_s2c, s2c); - - drop(client_peer); - drop(server_peer); - - let done = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate after peers close") - .expect("relay task must not panic"); - assert!(done.is_ok()); -} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 5714f48..04a7020 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -32,6 +32,7 @@ async fn drain_available(reader: &mut R, out: &mut Vec #[tokio::test] async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + const MAX_INPUT_CHUNK: usize = 12; for case in 
0..64u64 { let stats = Arc::new(Stats::new()); @@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); assert!( - recv_at_server.len() + recv_at_client.len() <= quota as usize, - "fuzz case {case}: delivered bytes exceed quota" + recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK, + "fuzz case {case}: delivered bytes exceed bounded post-check overshoot" ); assert!( - stats.get_user_total_octets(&user) <= quota, - "fuzz case {case}: accounted bytes exceed quota" + stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64, + "fuzz case {case}: accounted bytes exceed bounded post-check overshoot" ); } @@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); - assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK); + assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64); } } @@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode let stats = Arc::new(Stats::new()); let user = "quota-model-stress-user"; let quota = 96u64; + const WORKERS: usize = 6; + const MAX_WORKER_CHUNK: u64 = 10; + let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK; let mut workers = Vec::new(); - for worker_id in 0..6u64 { + for worker_id in 0..WORKERS as u64 { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_total_octets(user) <= quota, - "global per-user quota must never overshoot under concurrent multi-relay model load" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global per-user accounted bytes must stay within bounded post-write overshoot" ); assert!( - delivered_sum <= quota as usize, - "aggregate delivered bytes across relays must remain within global quota" + delivered_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate delivered bytes must stay within bounded post-write overshoot" ); } diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index dfbab85..f1e6c34 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -19,13 +19,22 @@ async fn read_available(reader: &mut R, budget_ms: u64) -> total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn 
regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; + let quota = 10u64; + let preloaded = 9u64; + let attempted_chunk = [0x11, 0x22, 0x33, 0x44]; + let max_post_write_overshoot = attempted_chunk.len() as u64; // Leave only 1 byte remaining under quota. - stats.add_user_octets_from(user, 9); + preload_user_quota(stats.as_ref(), user, preloaded); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ 512, user, Arc::clone(&stats), - Some(10), + Some(quota), Arc::new(BufferPool::new()), )); // Single chunk attempts to cross remaining budget (4 > 1). - client_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .unwrap(); + client_peer.write_all(&attempted_chunk).await.unwrap(); client_peer.shutdown().await.unwrap(); let forwarded = read_available(&mut server_peer, 60).await; @@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ .expect("relay must terminate after quota overflow attempt") .expect("relay task must not panic"); - assert_eq!( - forwarded, 0, - "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + assert!( + forwarded <= attempted_chunk.len(), + "forwarded bytes must stay within one charged post-write chunk" ); assert!(matches!( relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}) )); assert!( - stats.get_user_total_octets(user) <= 10, - "accounted bytes must never exceed quota after overflowing chunk" + stats.get_user_quota_used(user) <= quota + max_post_write_overshoot, + "accounted bytes must stay within bounded post-write overshoot" ); } @@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of let user = "quota-overflow-regression-boundary"; // Leave exactly 4 bytes remaining. - stats.add_user_octets_from(user, 6); + preload_user_quota(stats.as_ref(), user, 6); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-stress"; let quota = 12u64; + const WORKERS: usize = 4; + const BURST_LEN: usize = 64; + let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64; let mut handles = Vec::new(); - for _ in 0..4usize { + for _ in 0..WORKERS { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { }); // Aggressive sender tries to overflow shared user quota. 
- let burst = vec![0x5Au8; 64]; + let burst = vec![0x5Au8; BURST_LEN]; let _ = client_peer.write_all(&burst).await; let _ = client_peer.shutdown().await; @@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { } assert!( - forwarded_sum <= quota as usize, - "aggregate forwarded bytes across relays must stay within global user quota" + forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate forwarded bytes must stay within bounded post-write overshoot window" ); assert!( - stats.get_user_total_octets(user) <= quota, - "global accounted bytes must stay within quota under overflow stress" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global accounted bytes must stay within bounded post-write overshoot window" ); } diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs deleted file mode 100644 index 9f68258..0000000 --- a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs +++ /dev/null @@ -1,294 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -fn saturate_lock_cache() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); - } - retained -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_writer_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-writer-positive"; - let stats = 
Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); - - // Let the initial deferred wake fire while contention is still active. - tokio::time::sleep(Duration::from_millis(4)).await; - - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), writer) - .await - .expect("writer must be re-polled and complete after lock release") - .expect("writer task must not panic"); - assert!(completed.is_ok(), "writer must complete after lock release"); -} - -#[tokio::test] -async fn edge_reader_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-reader-edge"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before read"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let reader = tokio::spawn(async move { - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), reader) - .await - .expect("reader must be re-polled and complete after lock release") - .expect("reader task must not panic"); - assert!(completed.is_ok(), "reader 
must complete after lock release"); -} - -#[tokio::test] -async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-adversarial"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before adversarial write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); - - // Force multiple scheduler rounds while lock remains held so the first - // deferred wake has already been consumed under contention. - for _ in 0..32 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - let completed = timeout(Duration::from_millis(300), writer) - .await - .expect("writer must not stay parked forever after release") - .expect("writer task must not panic"); - assert!(completed.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_single_release_event() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = format!("quota-liveness-integration-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let barrier = Arc::new(Barrier::new(13)); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before launching waiters"); - - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let barrier = Arc::clone(&barrier); - waiters.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let 
quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(4096), - quota_exceeded, - tokio::time::Instant::now(), - ); - barrier.wait().await; - io.write_all(&[0x33]).await - })); - } - - barrier.wait().await; - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let outcome = waiter.await.expect("waiter must not panic"); - assert!( - outcome.is_ok(), - "waiter must resume and complete after release" - ); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn light_fuzz_release_timing_matrix_preserves_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - let mut seed = 0xD1CE_F00D_0123_4567u64; - for round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let delay_ms = 1 + (seed & 0x7) as u64; - let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock in fuzz round"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); - - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), writer) - .await - .expect("fuzz round writer must complete") - .expect("fuzz writer task must not panic"); - assert!( - done.is_ok(), - "fuzz round writer must not stall after release" - ); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn 
stress_repeated_contention_cycles_remain_live() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - for cycle in 0..40u32 { - let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before stress cycle"); - - let mut tasks = Vec::new(); - for _ in 0..6 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - io.write_all(&[0x55]).await - })); - } - - tokio::task::yield_now().await; - drop(held_guard); - - timeout(Duration::from_millis(700), async { - for task in tasks { - let outcome = task.await.expect("stress task must not panic"); - assert!(outcome.is_ok(), "stress writer must complete"); - } - }) - .await - .expect("stress cycle must finish in bounded time"); - } -} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs deleted file mode 100644 index fa4878a..0000000 --- a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs +++ /dev/null @@ -1,310 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, 
Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_contended_writer_emits_deferred_wake_for_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-positive-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(pending.is_pending()); - - timeout(Duration::from_millis(100), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must receive deferred wake"); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!( - ready.is_ready(), - "writer must progress after contention release" - ); -} - -#[tokio::test] -async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = 
"quota-waker-blackhat-writer"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - for _ in 0..512 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); - assert!( - poll.is_pending(), - "writer must stay pending while lock is held" - ); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending writer retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_read_path_contention_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-read-edge"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling reader"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - for _ in 0..512 { - let mut buf = 
ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending reader retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let mut buf = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-fuzz-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before fuzz polling"); - - let counters_w = Arc::new(SharedCounters::new()); - let mut writer_io = StatsIo::new( - tokio::io::sink(), - counters_w, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let counters_r = Arc::new(SharedCounters::new()); - let mut reader_io = StatsIo::new( - tokio::io::empty(), - counters_r, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut seed = 0xBADC_0FFE_EE11_2211u64; - let mut storage = [0u8; 1]; - - for _ in 0..1024 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - if (seed & 1) == 0 { - let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); - assert!(poll.is_pending()); - } else { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - } - 
tokio::task::yield_now().await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 192, - "mixed contention fuzz must keep deferred wake count tightly bounded" - ); - - drop(held_guard); - let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); - assert!(ready_w.is_ready()); - - let mut buf = ReadBuf::new(&mut storage); - let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(ready_r.is_ready()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] -async fn stress_many_contended_writers_complete_after_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-waker-stress-user".to_string(); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before launching contended tasks"); - - let mut tasks = Vec::new(); - for _ in 0..32 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - io.write_all(&[0xAA]).await - })); - } - - for _ in 0..8 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress task must not panic"); - assert!(result.is_ok(), "task must complete after lock release"); - } - }) - .await - .expect("all contended writer tasks must finish in bounded time after release"); -} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs deleted file mode 100644 index 50cdfa3..0000000 --- 
a/src/proxy/tests/relay_security_tests.rs +++ /dev/null @@ -1,1284 +0,0 @@ -use super::relay_bidirectional; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::future::poll_fn; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, ReadBuf}; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -#[tokio::test] -async fn quota_lock_contention_does_not_self_wake_pending_writer() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-contention-user"; - - let lock = super::quota_user_lock(user); - let _held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - poll.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "contended quota 
lock must not self-wake immediately and spin the executor" - ); -} - -#[tokio::test] -async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-writer-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - first.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "deferred wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must schedule a deferred wake in bounded time"); - let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_first_yield >= 1, - "contended writer must schedule at least one deferred wake for liveness" - ); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!( - second.is_pending(), - "writer remains pending while lock is still held" - ); - - for _ in 0..8 { - tokio::task::yield_now().await; - } - assert_eq!( - 
wake_counter.wakes.load(Ordering::Relaxed), - wakes_after_first_yield, - "writer contention should not schedule unbounded wake storms before lock acquisition" - ); - - drop(held_lock); - let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!( - released.is_ready(), - "writer must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-read-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling reader"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!( - first.is_pending(), - "reader must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "read contention wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("read contention must schedule a deferred wake in bounded time"); - - drop(held_lock); - let mut buf_after_release = ReadBuf::new(&mut storage); - let 
released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!( - released.is_ready(), - "reader must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn relay_bidirectional_enforces_live_user_quota() { - let stats = Arc::new(Stats::new()); - let user = "quota-user"; - stats.add_user_octets_from(user, 6); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - Some(8), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .expect("client write must succeed"); - - let mut forwarded = [0u8; 4]; - let _ = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut forwarded), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), - "relay must surface a typed quota error once live quota is exceeded" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - 
client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0xde, 0xad, 0xbe, 0xef]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "no full server payload should be forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() - { - let stats = Arc::new(Stats::new()); - let quota_user = "partial-leak-user"; - stats.add_user_octets_from(quota_user, 3); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 8]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - 
.expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { - let stats = Arc::new(Stats::new()); - let quota_user = "zero-quota-user"; - - for payload_len in [1usize, 16, 512, 4096] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0x7f; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under zero-quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "zero quota must not forward any server bytes for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "zero quota must terminate with the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { - let stats = Arc::new(Stats::new()); - let 
quota_user = "exact-boundary-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x91, 0x92, 0x93, 0x94]) - .await - .expect("server write must succeed at exact quota boundary"); - - let mut observed = [0u8; 4]; - client_peer - .read_exact(&mut observed) - .await - .expect("client must receive the full payload at the exact quota boundary"); - assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish after exact boundary delivery") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must close with a typed quota error after reaching the exact boundary" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "client-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x51, 0x52, 0x53, 
0x54]) - .await - .expect("client write must succeed even when quota is already exhausted"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "client payload must not be fully forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-fuzz-user"; - stats.add_user_octets_from(quota_user, 2); - - for payload_len in [1usize, 32, 1024, 8192] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(2), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xaa; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == payload_len), - "quota 
exhaustion must block full server-to-client forwarding for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must keep returning the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_terminates_on_activity_timeout() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "timeout-user"; - - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, // No quota - Arc::new(BufferPool::new()), - )); - - // Wait past the activity timeout threshold (1800 seconds) + buffer - tokio::time::sleep(Duration::from_secs(1805)).await; - - // Resume time to process timeouts - tokio::time::resume(); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must finish inside bounded timeout due to inactivity cutoff") - .expect("relay task must not panic"); - - assert!( - relay_result.is_ok(), - "relay should complete successfully on scheduled inactivity timeout" - ); - - // Verify client/server sockets are closed - drop(client_peer); - drop(server_peer); -} - -#[tokio::test] -async fn relay_bidirectional_watchdog_resists_premature_execution() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "activity-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = 
tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - // Advance by half the timeout - tokio::time::sleep(Duration::from_secs(900)).await; - - // Provide activity - client_peer - .write_all(&[0xaa, 0xbb]) - .await - .expect("client write must succeed"); - client_peer.flush().await.unwrap(); - - // Advance by another half (total time since start is 1800, but since last activity is 900) - tokio::time::sleep(Duration::from_secs(900)).await; - - tokio::time::resume(); - - // Re-evaluating the task, it should NOT have timed out and still be pending - let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; - assert!( - relay_result.is_err(), - "Relay must not exit prematurely as long as activity was received before timeout" - ); - - // Explicitly drop sockets to cleanly shut down relay loop - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete securely after client disconnection") - .expect("relay task must not panic"); - assert!(completion.is_ok(), "relay exits clean"); -} - -#[tokio::test] -async fn relay_bidirectional_half_closure_terminates_cleanly() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "half-close", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Half closure: drop the client completely but leave the server active. - drop(client_peer); - - // Check that we don't immediately crash. 
Bidirectional relay stays open for the server -> client flush. - // Eventually dropping the server cleanly closes the task. - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_zero_length_noise_fuzzing() { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) - for _ in 0..100 { - client_peer.write_all(&[]).await.unwrap(); - } - client_peer.write_all(&[1, 2, 3]).await.unwrap(); - client_peer.flush().await.unwrap(); - - let mut buf = [0u8; 3]; - server_peer.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, &[1, 2, 3]); - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_asymmetric_backpressure() { - let stats = Arc::new(Stats::new()); - // Give the client stream an extremely narrow throughput limit explicitly - let (client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "slowloris", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xba; 65536]; // 64k 
payload - - // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! - let write_res = - tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; - - assert!( - write_res.is_err(), - "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" - ); - - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap(); - assert!( - completion.is_ok() || completion.is_err(), - "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" - ); -} - -use rand::{RngExt, SeedableRng, rngs::StdRng}; - -#[tokio::test] -async fn relay_bidirectional_light_fuzzing_temporal_jitter() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz-user", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let mut rng = StdRng::seed_from_u64(0xDEADBEEF); - - for _ in 0..10 { - // Vary timing significantly up to 1600 seconds (limit is 1800s) - let jitter = rng.random_range(100..1600); - tokio::time::sleep(Duration::from_secs(jitter)).await; - - client_peer.write_all(&[0x11]).await.unwrap(); - client_peer.flush().await.unwrap(); - - // Ensure task has not died - let res = timeout(Duration::from_millis(10), &mut relay_task).await; - assert!( - res.is_err(), - "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" - ); - } - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - 
.unwrap(); -} - -struct FaultyReader { - error_once: Option, -} - -struct TwoPartyGate { - arrivals: AtomicUsize, - total_bytes: AtomicUsize, - wakers: Mutex>, -} - -impl TwoPartyGate { - fn new() -> Self { - Self { - arrivals: AtomicUsize::new(0), - total_bytes: AtomicUsize::new(0), - wakers: Mutex::new(Vec::new()), - } - } - - fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { - if self.arrivals.load(Ordering::Relaxed) >= 2 { - return true; - } - - let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); - if prev + 1 >= 2 { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - for waker in wakers.drain(..) { - waker.wake(); - } - true - } else { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - wakers.push(cx.waker().clone()); - false - } - } - - fn total_bytes(&self) -> usize { - self.total_bytes.load(Ordering::Relaxed) - } -} - -struct GateWriter { - gate: Arc, - entered: bool, -} - -impl GateWriter { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - self.gate - .total_bytes - .fetch_add(buf.len(), Ordering::Relaxed); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct GateReader { - gate: Arc, - entered: bool, - emitted: bool, -} - -impl GateReader { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - emitted: false, - } - } -} - -impl AsyncRead for GateReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.emitted { - return Poll::Ready(Ok(())); - } - - 
if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - buf.put_slice(&[0x42]); - self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); - self.emitted = true; - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-write".to_string(); - - let writer_a = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let writer_b = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut w = writer_a; - AsyncWriteExt::write_all(&mut w, &[0x01]).await - }); - let task_b = tokio::spawn(async move { - let mut w = writer_b; - AsyncWriteExt::write_all(&mut w, &[0x02]).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user writes must not forward more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user writes must not account over limit" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-read".to_string(); - - let reader_a = super::StatsIo::new( - 
GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let reader_b = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut r = reader_a; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - let task_b = tokio::spawn(async move { - let mut r = reader_b; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user reads must not consume more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user reads must not account over limit" - ); -} - -#[tokio::test] -async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { - let stats = Arc::new(Stats::new()); - let user = "parallel-quota-user"; - - for _ in 0..128 { - let (mut client_peer_a, relay_client_a) = duplex(256); - let (relay_server_a, mut server_peer_a) = duplex(256); - let (mut client_peer_b, relay_client_b) = duplex(256); - let (relay_server_b, mut server_peer_b) = duplex(256); - - let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); - let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); - let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); - let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); - - let relay_a = tokio::spawn(relay_bidirectional( - client_reader_a, - client_writer_a, - server_reader_a, 
- server_writer_a, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let relay_b = tokio::spawn(relay_bidirectional( - client_reader_b, - client_writer_b, - server_reader_b, - server_writer_b, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let _ = tokio::join!( - client_peer_a.write_all(&[0x01]), - server_peer_a.write_all(&[0x02]), - client_peer_b.write_all(&[0x03]), - server_peer_b.write_all(&[0x04]), - ); - - let _ = timeout( - Duration::from_millis(50), - poll_fn(|cx| { - let mut one = [0u8; 1]; - let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); - Poll::Ready(()) - }), - ) - .await; - - drop(client_peer_a); - drop(server_peer_a); - drop(client_peer_b); - drop(server_peer_b); - - let _ = timeout(Duration::from_secs(1), relay_a).await; - let _ = timeout(Duration::from_secs(1), relay_b).await; - - assert!( - stats.get_user_total_octets(user) <= 1, - "parallel relays must not exceed configured quota" - ); - } -} - -impl FaultyReader { - fn permission_denied_with_message(message: impl Into) -> Self { - Self { - error_once: Some(io::Error::new( - io::ErrorKind::PermissionDenied, - message.into(), - )), - } - } -} - -impl AsyncRead for FaultyReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - if let Some(err) = self.error_once.take() { - return Poll::Ready(Err(err)); - } - Poll::Ready(Ok(())) - } -} - -#[tokio::test] -async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message("user data quota exceeded"), - tokio::io::sink(), - 1024, - 1024, - 
"non-quota-permission-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "non-quota transport PermissionDenied errors must remain IO errors" - ); -} - -#[tokio::test] -async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { - let mut rng = StdRng::seed_from_u64(0xA11CE0B5); - - for i in 0..128u64 { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let random_len = rng.random_range(1..=48); - let mut msg = String::with_capacity(random_len); - for _ in 0..random_len { - let ch = (b'a' + (rng.random::() % 26)) as char; - msg.push(ch); - } - // Include the legacy quota string in a subset of fuzz cases to validate - // collision resistance against message-based classification. 
- if i % 7 == 0 { - msg = "user data quota exceeded".to_string(); - } - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message(msg), - tokio::io::sink(), - 1024, - 1024, - "fuzz-perm-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "transport PermissionDenied case must stay typed as IO regardless of message content" - ); - } -} - -#[tokio::test] -async fn relay_half_close_keeps_reverse_direction_progressing() { - let stats = Arc::new(Stats::new()); - let user = "half-close-user"; - - let (client_peer, relay_client) = duplex(1024); - let (relay_server, server_peer) = duplex(1024); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); - let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 8192, - 8192, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - sp_writer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .unwrap(); - sp_writer.shutdown().await.unwrap(); - - let mut inbound = [0u8; 4]; - cp_reader.read_exact(&mut inbound).await.unwrap(); - assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); - - cp_writer - .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) - .await - .unwrap(); - let mut outbound = [0u8; 4]; - sp_reader.read_exact(&mut outbound).await.unwrap(); - assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); - - relay_task.abort(); - let joined = relay_task.await; - assert!(joined.is_err(), "aborted relay task must return join error"); -} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index d13d834..ff15d4f 
100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -238,10 +238,12 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, + quota_write_fail_bytes_total: AtomicU64, + quota_write_fail_events_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, - user_stats: DashMap, + user_stats: DashMap>, user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock>, } @@ -254,9 +256,51 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + /// Total bytes charged against per-user quota admission. + /// + /// This counter is the single source of truth for quota enforcement and + /// intentionally tracks attempted traffic, not guaranteed delivery. + pub quota_used: AtomicU64, pub last_seen_epoch_secs: AtomicU64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QuotaReserveError { + LimitExceeded, + Contended, +} + +impl UserStats { + #[inline] + pub fn quota_used(&self) -> u64 { + self.quota_used.load(Ordering::Relaxed) + } + + /// Attempts one CAS reservation step against the quota counter. + /// + /// Callers control retry/yield policy. This primitive intentionally does + /// not block or sleep so both sync poll paths and async paths can wrap it + /// with their own contention strategy. 
+ #[inline] + pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result { + let current = self.quota_used.load(Ordering::Relaxed); + if bytes > limit.saturating_sub(current) { + return Err(QuotaReserveError::LimitExceeded); + } + + let next = current.saturating_add(bytes); + match self.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => Ok(next), + Err(_) => Err(QuotaReserveError::Contended), + } + } +} + impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -316,6 +360,74 @@ impl Stats { .store(Self::now_epoch_secs(), Ordering::Relaxed); } + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { + self.maybe_cleanup_user_stats(); + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + Self::touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + Self::touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: 
&UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event. + #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + Self::touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + fn maybe_cleanup_user_stats(&self) { const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; @@ -704,7 +816,8 @@ impl Stats { } pub fn increment_me_d2c_data_frames_total(&self) { if self.telemetry_me_allows_normal() { - self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed); + self.me_d2c_data_frames_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_d2c_ack_frames_total(&self) { @@ -1114,6 +1227,18 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_endpoint_quarantine_total(&self) { if self.telemetry_me_allows_normal() { self.me_endpoint_quarantine_total @@ -1588,7 +1713,8 @@ impl Stats { self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed) + self.me_d2c_batch_bytes_bucket_4k_16k + .load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { 
self.me_d2c_batch_bytes_bucket_16k_64k @@ -1764,19 +1890,19 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } + pub fn get_quota_write_fail_bytes_total(&self) -> u64 { + self.quota_write_fail_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_events_total(&self) -> u64 { + self.quota_write_fail_events_total.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.connects.fetch_add(1, Ordering::Relaxed); } @@ -1784,14 +1910,8 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.curr_connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } @@ -1800,9 +1920,8 @@ impl Stats { return true; } - self.maybe_cleanup_user_stats(); - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); @@ -1827,7 +1946,7 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let 
Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); + Self::touch_user_stats(stats.value().as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { @@ -1858,60 +1977,32 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_from_handle(stats.as_ref(), bytes); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_to_handle(stats.as_ref(), bytes); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_from_handle(stats.as_ref()); } pub fn 
increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_to_handle(stats.as_ref()); } pub fn get_user_total_octets(&self, user: &str) -> u64 { @@ -1924,6 +2015,13 @@ impl Stats { .unwrap_or(0) } + pub fn get_user_quota_used(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.quota_used.load(Ordering::Relaxed)) + .unwrap_or(0) + } + pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } @@ -1989,7 +2087,7 @@ impl Stats { .load(Ordering::Relaxed) } - pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> { + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { self.user_stats.iter() } @@ -2137,6 +2235,22 @@ impl ReplayChecker { found } + fn check_only_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = shards[idx].lock(); + let found = shard.check(data, Instant::now(), window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); @@ -2160,7 +2274,7 @@ impl ReplayChecker { self.add_only(data, &self.handshake_shards, self.window) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_tls_digest(data) + self.check_only_internal(data, &self.tls_shards, self.tls_window) 
} pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) @@ -2264,6 +2378,7 @@ mod tests { use super::*; use crate::config::MeTelemetryLevel; use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; #[test] fn test_stats_shared_counters() { @@ -2431,6 +2546,137 @@ mod tests { } assert_eq!(checker.stats().total_entries, 500); } + + #[test] + fn test_quota_reserve_under_contention_hits_limit_exactly() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let limit = 8_192u64; + let mut workers = Vec::new(); + + for _ in 0..8 { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(1, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + Err(QuotaReserveError::LimitExceeded) => { + break; + } + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + limit, + "successful reservations must stop exactly at limit" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let failures = Arc::new(AtomicU64::new(0)); + let attempts = 200usize; + let reserve_bytes = 1_024u64; + let limit = 100 * 1_024u64; + let mut workers = Vec::with_capacity(attempts); + + for _ in 0..attempts { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + let failures = failures.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(reserve_bytes, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + return; + } + 
Err(QuotaReserveError::LimitExceeded) => { + failures.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + } + } + })); + } + + for worker in workers { + worker.join().expect("reservation worker must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + 100, + "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota" + ); + assert_eq!( + failures.load(Ordering::Relaxed), + 100, + "remaining workers must fail once quota is fully reserved" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { + let stats = Stats::new(); + let user = "quota-authoritative-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + + stats.add_user_octets_to_handle(&user_stats, 5); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 0); + + stats.quota_charge_post_write(&user_stats, 7); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 7); + } + + #[test] + fn test_cached_handle_survives_map_cleanup_until_last_drop() { + let stats = Stats::new(); + let user = "quota-handle-lifetime-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let weak = Arc::downgrade(&user_stats); + + stats.user_stats.remove(user); + assert!( + stats.user_stats.get(user).is_none(), + "map cleanup should remove idle entry" + ); + assert!( + weak.upgrade().is_some(), + "cached handle must keep user stats object alive after map removal" + ); + + stats.quota_charge_post_write(user_stats.as_ref(), 3); + assert_eq!(user_stats.quota_used(), 3); + + drop(user_stats); + assert!( + weak.upgrade().is_none(), + "user stats object must be dropped after the last cached handle is released" + ); + } } #[cfg(test)] diff --git a/src/stream/frame_stream_padding_security_tests.rs 
b/src/stream/frame_stream_padding_security_tests.rs index 83b30f9..1ec787e 100644 --- a/src/stream/frame_stream_padding_security_tests.rs +++ b/src/stream/frame_stream_padding_security_tests.rs @@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() { let old = old_padding_round_up_to_4(len).expect("old expression must be safe"); let new = new_padding_round_up_to_4(len).expect("new expression must be safe"); assert_eq!(old, new, "mismatch for len={len}"); - assert!(new >= len, "rounded length must not shrink: len={len}, out={new}"); + assert!( + new >= len, + "rounded length must not shrink: len={len}, out={new}" + ); assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned"); } } diff --git a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs index cf42e75..3fc9727 100644 --- a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs +++ b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs @@ -44,7 +44,10 @@ async fn encapsulation_repeated_queue_poison_recovery_preserves_forward_progress let ip_primary = ip_from_idx(10_001); let ip_alt = ip_from_idx(10_002); - tracker.check_and_add("encap-poison", ip_primary).await.unwrap(); + tracker + .check_and_add("encap-poison", ip_primary) + .await + .unwrap(); for _ in 0..128 { let queue = tracker.cleanup_queue_mutex_for_tests(); diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 4408b5a..bbfc336 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,7 +1,9 @@ #![allow(clippy::too_many_arguments)] +use dashmap::DashMap; use std::sync::Arc; -use std::time::Duration; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use anyhow::{Result, anyhow}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -21,7 +23,8 @@ use rustls::{DigitallySignedStruct, Error as RustlsError}; use x509_parser::certificate::X509Certificate; use x509_parser::prelude::FromDer; -use 
crate::crypto::SecureRandom; +use crate::config::TlsFetchProfile; +use crate::crypto::{SecureRandom, sha256}; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, @@ -78,6 +81,200 @@ impl ServerCertVerifier for NoVerify { } } +#[derive(Debug, Clone)] +pub struct TlsFetchStrategy { + pub profiles: Vec, + pub strict_route: bool, + pub attempt_timeout: Duration, + pub total_budget: Duration, + pub grease_enabled: bool, + pub deterministic: bool, + pub profile_cache_ttl: Duration, +} + +impl TlsFetchStrategy { + #[allow(dead_code)] + pub fn single_attempt(connect_timeout: Duration) -> Self { + Self { + profiles: vec![TlsFetchProfile::CompatTls12], + strict_route: false, + attempt_timeout: connect_timeout.max(Duration::from_millis(1)), + total_budget: connect_timeout.max(Duration::from_millis(1)), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::ZERO, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ProfileCacheKey { + host: String, + port: u16, + sni: String, + scope: Option, + proxy_protocol: u8, + route_hint: RouteHint, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RouteHint { + Direct, + Upstream, + Unix, +} + +#[derive(Debug, Clone, Copy)] +struct ProfileCacheValue { + profile: TlsFetchProfile, + updated_at: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FetchErrorKind { + Connect, + Route, + EarlyEof, + Timeout, + ServerHelloMissing, + TlsAlert, + Parse, + Other, +} + +static PROFILE_CACHE: OnceLock> = OnceLock::new(); + +fn profile_cache() -> &'static DashMap { + PROFILE_CACHE.get_or_init(DashMap::new) +} + +fn route_hint( + upstream: Option<&std::sync::Arc>, + unix_sock: Option<&str>, +) -> RouteHint { + if unix_sock.is_some() { + RouteHint::Unix + } else if upstream.is_some() { + RouteHint::Upstream + } else { + RouteHint::Direct + } +} + +fn profile_cache_key( + host: 
&str, + port: u16, + sni: &str, + upstream: Option<&std::sync::Arc>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> ProfileCacheKey { + ProfileCacheKey { + host: host.to_string(), + port, + sni: sni.to_string(), + scope: scope.map(ToString::to_string), + proxy_protocol, + route_hint: route_hint(upstream, unix_sock), + } +} + +fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind { + for cause in err.chain() { + if let Some(io) = cause.downcast_ref::() { + return match io.kind() { + std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout, + std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof, + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect, + _ => FetchErrorKind::Other, + }; + } + } + + let message = err.to_string().to_lowercase(); + if message.contains("upstream route") { + FetchErrorKind::Route + } else if message.contains("serverhello not received") { + FetchErrorKind::ServerHelloMissing + } else if message.contains("alert") { + FetchErrorKind::TlsAlert + } else if message.contains("parse") { + FetchErrorKind::Parse + } else if message.contains("timed out") || message.contains("deadline has elapsed") { + FetchErrorKind::Timeout + } else if message.contains("eof") { + FetchErrorKind::EarlyEof + } else { + FetchErrorKind::Other + } +} + +fn order_profiles( + strategy: &TlsFetchStrategy, + cache_key: Option<&ProfileCacheKey>, + now: Instant, +) -> Vec { + let mut ordered = if strategy.profiles.is_empty() { + vec![TlsFetchProfile::CompatTls12] + } else { + strategy.profiles.clone() + }; + + if strategy.profile_cache_ttl.is_zero() { + return ordered; + } + + let Some(key) = cache_key else { + return ordered; + }; + + if let Some(cached) = profile_cache().get(key) { + let age = now.saturating_duration_since(cached.updated_at); + if age > 
strategy.profile_cache_ttl { + drop(cached); + profile_cache().remove(key); + return ordered; + } + + if let Some(pos) = ordered + .iter() + .position(|profile| *profile == cached.profile) + { + if pos != 0 { + ordered.swap(0, pos); + } + } + } + + ordered +} + +fn remember_profile_success( + strategy: &TlsFetchStrategy, + cache_key: Option, + profile: TlsFetchProfile, + now: Instant, +) { + if strategy.profile_cache_ttl.is_zero() { + return; + } + let Some(key) = cache_key else { + return; + }; + profile_cache().insert( + key, + ProfileCacheValue { + profile, + updated_at: now, + }, + ); +} + fn build_client_config() -> Arc { let root = rustls::RootCertStore::empty(); @@ -95,7 +292,114 @@ fn build_client_config() -> Arc { Arc::new(config) } -fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { +fn deterministic_bytes(seed: &str, len: usize) -> Vec { + let mut out = Vec::with_capacity(len); + let mut counter: u32 = 0; + while out.len() < len { + let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::()); + chunk_seed.extend_from_slice(seed.as_bytes()); + chunk_seed.extend_from_slice(&counter.to_le_bytes()); + out.extend_from_slice(&sha256(&chunk_seed)); + counter = counter.wrapping_add(1); + } + out.truncate(len); + out +} + +fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN_CHROME: &[u16] = &[ + 0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff, + ]; + const MODERN_FIREFOX: &[u16] = &[ + 0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff, + ]; + const COMPAT_TLS12: &[u16] = &[ + 0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff, + ]; + const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff]; + + match profile { + TlsFetchProfile::ModernChromeLike => MODERN_CHROME, + TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX, + TlsFetchProfile::CompatTls12 => COMPAT_TLS12, + 
TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL, + } +} + +fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1 + const COMPAT: &[u16] = &[0x001d, 0x0017]; + const LEGACY: &[u16] = &[0x0017]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806]; + const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805]; + const LEGACY: &[u16] = &[0x0403, 0x0804]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] { + const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"]; + const HTTP11: &[&[u8]] = &[b"http/1.1"]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11, + TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11, + } +} + +fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0304, 0x0303]; + const COMPAT: &[u16] = &[0x0303, 0x0304]; + const LEGACY: &[u16] = &[0x0303]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_padding_target(profile: TlsFetchProfile) -> usize { + match profile { + TlsFetchProfile::ModernChromeLike => 220, + TlsFetchProfile::ModernFirefoxLike => 200, + TlsFetchProfile::CompatTls12 => 180, + TlsFetchProfile::LegacyMinimal => 64, + } +} + +fn grease_value(rng: &SecureRandom, deterministic: bool, 
seed: &str) -> u16 { + const GREASE_VALUES: [u16; 16] = [ + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, + 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa, + ]; + if deterministic { + let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } else { + let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } +} + +fn build_client_hello( + sni: &str, + rng: &SecureRandom, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, +) -> Vec { // === ClientHello body === let mut body = Vec::new(); @@ -103,21 +407,24 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { body.extend_from_slice(&[0x03, 0x03]); // Random - body.extend_from_slice(&rng.bytes(32)); + if deterministic { + body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32)); + } else { + body.extend_from_slice(&rng.bytes(32)); + } // Session ID: empty body.push(0); - // Cipher suites (common minimal set, TLS1.3 + a few 1.2 fallbacks) - let cipher_suites: [u8; 10] = [ - 0x13, 0x01, // TLS_AES_128_GCM_SHA256 - 0x13, 0x02, // TLS_AES_256_GCM_SHA384 - 0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256 - 0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA (legacy) - 0x00, 0xff, // RENEGOTIATION_INFO_SCSV - ]; - body.extend_from_slice(&(cipher_suites.len() as u16).to_be_bytes()); - body.extend_from_slice(&cipher_suites); + let mut cipher_suites = profile_cipher_suites(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("cipher:{sni}")); + cipher_suites.insert(0, grease); + } + body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); + for suite in cipher_suites { + body.extend_from_slice(&suite.to_be_bytes()); + } // Compression methods: null only body.push(1); @@ -138,7 +445,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&sni_ext); // supported_groups - let groups: 
[u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1 + let mut groups = profile_groups(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("group:{sni}")); + groups.insert(0, grease); + } exts.extend_from_slice(&0x000au16.to_be_bytes()); exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); @@ -147,7 +458,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // signature_algorithms - let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, rsa_pkcs1_sha256 + let mut sig_algs = profile_sig_algs(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}")); + sig_algs.insert(0, grease); + } exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); @@ -155,8 +470,12 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&a.to_be_bytes()); } - // supported_versions (TLS1.3 + TLS1.2) - let versions: [u16; 2] = [0x0304, 0x0303]; + // supported_versions + let mut versions = profile_supported_versions(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("version:{sni}")); + versions.insert(0, grease); + } exts.extend_from_slice(&0x002bu16.to_be_bytes()); exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); exts.push((versions.len() * 2) as u8); @@ -165,7 +484,14 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // key_share (x25519) - let key = gen_key_share(rng); + let key = if deterministic { + let det = deterministic_bytes(&format!("keyshare:{sni}"), 32); + let mut key = [0u8; 32]; + key.copy_from_slice(&det); + key + } else { + gen_key_share(rng) + }; let mut keyshare = 
Vec::with_capacity(4 + key.len()); keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); @@ -175,18 +501,29 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); exts.extend_from_slice(&keyshare); - // ALPN (http/1.1) - let alpn_proto = b"http/1.1"; - exts.extend_from_slice(&0x0010u16.to_be_bytes()); - exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.push(alpn_proto.len() as u8); - exts.extend_from_slice(alpn_proto); + // ALPN + let mut alpn_list = Vec::new(); + for proto in profile_alpn(profile) { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + if !alpn_list.is_empty() { + exts.extend_from_slice(&0x0010u16.to_be_bytes()); + exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + exts.extend_from_slice(&alpn_list); + } + + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("ext:{sni}")); + exts.extend_from_slice(&grease.to_be_bytes()); + exts.extend_from_slice(&0u16.to_be_bytes()); + } // padding to reduce recognizability and keep length ~500 bytes - const TARGET_EXT_LEN: usize = 180; - if exts.len() < TARGET_EXT_LEN { - let remaining = TARGET_EXT_LEN - exts.len(); + let target_ext_len = profile_padding_target(profile); + if exts.len() < target_ext_len { + let remaining = target_ext_len - exts.len(); if remaining > 4 { let pad_len = remaining - 4; // minus type+len exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension @@ -402,27 +739,41 @@ async fn connect_tcp_with_upstream( connect_timeout: Duration, upstream: Option>, scope: Option<&str>, + strict_route: bool, ) -> Result { if let Some(manager) = upstream { - if let Some(addr) = resolve_socket_addr(host, 
port) { - match manager.connect(addr, None, scope).await { - Ok(stream) => return Ok(stream), - Err(e) => { - warn!( - host = %host, - port = port, - scope = ?scope, - error = %e, - "Upstream connect failed, using direct connect" - ); - } - } - } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await - && let Some(addr) = addrs.find(|a| a.is_ipv4()) - { + let resolved = if let Some(addr) = resolve_socket_addr(host, port) { + Some(addr) + } else { + match tokio::net::lookup_host((host, port)).await { + Ok(mut addrs) => addrs.find(|a| a.is_ipv4()), + Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route DNS resolution failed for {host}:{port}: {e}" + )); + } + warn!( + host = %host, + port = port, + scope = ?scope, + error = %e, + "Upstream DNS resolution failed, using direct connect" + ); + None + } + } + }; + + if let Some(addr) = resolved { match manager.connect(addr, None, scope).await { Ok(stream) => return Ok(stream), Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route connect failed for {host}:{port}: {e}" + )); + } warn!( host = %host, port = port, @@ -432,6 +783,10 @@ async fn connect_tcp_with_upstream( ); } } + } else if strict_route { + return Err(anyhow!( + "upstream route resolution produced no usable address for {host}:{port}" + )); } } Ok(UpstreamStream::Tcp( @@ -471,12 +826,15 @@ async fn fetch_via_raw_tls_stream( sni: &str, connect_timeout: Duration, proxy_protocol: u8, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { let rng = SecureRandom::new(); - let client_hello = build_client_hello(sni, &rng); + let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { if proxy_protocol > 0 { let header = match proxy_protocol { @@ -550,6 +908,10 @@ async fn fetch_via_raw_tls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, + 
profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -560,8 +922,16 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); - return fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol) - .await; + return fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await; } Ok(Err(e)) => { warn!( @@ -584,8 +954,19 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; - fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; + fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await } async fn fetch_via_rustls_stream( @@ -691,6 +1072,7 @@ async fn fetch_via_rustls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -724,16 +1106,153 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } -/// Fetch real TLS metadata for the given SNI. -/// -/// Strategy: -/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes. -/// 2) Fetch certificate chain via rustls to build cert payload. -/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded. +/// Fetch real TLS metadata with an adaptive multi-profile strategy. 
+pub async fn fetch_real_tls_with_strategy( + host: &str, + port: u16, + sni: &str, + strategy: &TlsFetchStrategy, + upstream: Option>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> Result { + let attempt_timeout = strategy.attempt_timeout.max(Duration::from_millis(1)); + let total_budget = strategy.total_budget.max(Duration::from_millis(1)); + let started_at = Instant::now(); + let cache_key = profile_cache_key( + host, + port, + sni, + upstream.as_ref(), + scope, + proxy_protocol, + unix_sock, + ); + let profiles = order_profiles(strategy, Some(&cache_key), started_at); + + let mut raw_result = None; + let mut raw_last_error: Option = None; + let mut raw_last_error_kind = FetchErrorKind::Other; + let mut selected_profile = None; + + for profile in profiles { + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + break; + } + let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed); + + match fetch_via_raw_tls( + host, + port, + sni, + timeout_for_attempt, + upstream.clone(), + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + profile, + strategy.grease_enabled, + strategy.deterministic, + ) + .await + { + Ok(res) => { + selected_profile = Some(profile); + raw_result = Some(res); + break; + } + Err(err) => { + let kind = classify_fetch_error(&err); + warn!( + sni = %sni, + profile = profile.as_str(), + error_kind = ?kind, + error = %err, + "Raw TLS fetch attempt failed" + ); + raw_last_error_kind = kind; + raw_last_error = Some(err); + if strategy.strict_route && matches!(kind, FetchErrorKind::Route) { + break; + } + } + } + } + + if let Some(profile) = selected_profile { + remember_profile_success(strategy, Some(cache_key), profile, Instant::now()); + } + + if raw_result.is_none() + && strategy.strict_route + && matches!(raw_last_error_kind, FetchErrorKind::Route) + { + if let Some(err) = raw_last_error { + return Err(err); + } + return Err(anyhow!("TLS fetch strict-route failure")); + } 
+ + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + return match raw_result { + Some(raw) => Ok(raw), + None => { + Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))) + } + }; + } + + let rustls_timeout = attempt_timeout.min(total_budget - elapsed); + let rustls_result = fetch_via_rustls( + host, + port, + sni, + rustls_timeout, + upstream, + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + ) + .await; + + match rustls_result { + Ok(rustls) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls.cert_info; + raw.cert_payload = rustls.cert_payload; + raw.behavior_profile.source = TlsProfileSource::Merged; + debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls) + } + } + Err(err) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else if let Some(raw_err) = raw_last_error { + Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) + } else { + Err(err) + } + } + } +} + +/// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. 
+#[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, @@ -744,62 +1263,30 @@ pub async fn fetch_real_tls( proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls( + let strategy = TlsFetchStrategy::single_attempt(connect_timeout); + fetch_real_tls_with_strategy( host, port, sni, - connect_timeout, - upstream.clone(), - scope, - proxy_protocol, - unix_sock, - ) - .await - { - Ok(res) => Some(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); - None - } - }; - - match fetch_via_rustls( - host, - port, - sni, - connect_timeout, + &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await - { - Ok(rustls_result) => { - if let Some(mut raw) = raw_result { - raw.cert_info = rustls_result.cert_info; - raw.cert_payload = rustls_result.cert_payload; - raw.behavior_profile.source = TlsProfileSource::Merged; - debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); - Ok(raw) - } else { - Ok(rustls_result) - } - } - Err(e) => { - if let Some(raw) = raw_result { - warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); - Ok(raw) - } else { - Err(e) - } - } - } } #[cfg(test)] mod tests { - use super::{derive_behavior_profile, encode_tls13_certificate_message}; + use std::time::{Duration, Instant}; + + use super::{ + ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, + encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + }; + use crate::config::TlsFetchProfile; + use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; @@ -812,8 +1299,8 @@ mod tests { #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; - let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)) - .expect("message"); + let message = + 
encode_tls13_certificate_message(std::slice::from_ref(&cert)).expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); @@ -848,4 +1335,93 @@ mod tests { assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } + + #[test] + fn test_order_profiles_prioritizes_fresh_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(60), + }; + let cache_key = profile_cache_key( + "mask.example", + 443, + "tls.example", + None, + Some("tls"), + 0, + None, + ); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now(), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::CompatTls12); + profile_cache().remove(&cache_key); + } + + #[test] + fn test_order_profiles_drops_expired_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(5), + }; + let cache_key = + profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now() - Duration::from_secs(6), + }, + ); + + let ordered = order_profiles(&strategy, 
Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike); + assert!(profile_cache().get(&cache_key).is_none()); + } + + #[test] + fn test_deterministic_client_hello_is_stable() { + let rng = SecureRandom::new(); + let first = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + let second = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + + assert_eq!(first, second); + } } diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 8e5a701..ba90c1a 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -11,17 +11,19 @@ use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::error::Result; +use crate::transport::UpstreamManager; use super::MePool; +use super::http_fetch::https_get; use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger}; -use super::secret::download_proxy_secret_with_max_len; +use super::secret::download_proxy_secret_with_max_len_via_upstream; use super::selftest::record_timeskew_sample; use std::time::SystemTime; -async fn retry_fetch(url: &str) -> Option { +async fn retry_fetch(url: &str, upstream: Option>) -> Option { let delays = [1u64, 5, 15]; for (i, d) in delays.iter().enumerate() { - match fetch_proxy_config(url).await { + match fetch_proxy_config_via_upstream(url, upstream.clone()).await { Ok(cfg) => return Some(cfg), Err(e) => { if i == delays.len() - 1 { @@ -95,14 +97,19 @@ pub async fn save_proxy_config_cache(path: &str, raw_text: &str) -> Result<()> { Ok(()) } +#[allow(dead_code)] pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, String)> { - let resp = reqwest::get(url).await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")) - })?; - let http_status = resp.status().as_u16(); + 
fetch_proxy_config_with_raw_via_upstream(url, None).await +} - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() +pub async fn fetch_proxy_config_with_raw_via_upstream( + url: &str, + upstream: Option>, +) -> Result<(ProxyConfigData, String)> { + let resp = https_get(url, upstream).await?; + let http_status = resp.status; + + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -123,9 +130,7 @@ pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, } } - let text = resp.text().await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")) - })?; + let text = String::from_utf8_lossy(&resp.body).into_owned(); let parsed = parse_proxy_config_text(&text, http_status); Ok((parsed, text)) } @@ -260,8 +265,16 @@ fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> { Some((dc, ip, port)) } +#[allow(dead_code)] pub async fn fetch_proxy_config(url: &str) -> Result { - fetch_proxy_config_with_raw(url) + fetch_proxy_config_via_upstream(url, None).await +} + +pub async fn fetch_proxy_config_via_upstream( + url: &str, + upstream: Option>, +) -> Result { + fetch_proxy_config_with_raw_via_upstream(url, upstream) .await .map(|(parsed, _raw)| parsed) } @@ -300,6 +313,7 @@ async fn run_update_cycle( state: &mut UpdaterState, reinit_tx: &mpsc::Sender, ) { + let upstream = pool.upstream.clone(); pool.update_runtime_reinit_policy( cfg.general.hardswap, cfg.general.me_pool_drain_ttl_secs, @@ -354,7 +368,7 @@ async fn run_update_cycle( let mut maps_changed = false; let mut ready_v4: Option<(ProxyConfigData, u64)> = None; - let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; + let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig", upstream.clone()).await; if let Some(cfg_v4) = cfg_v4 && 
snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") { @@ -378,7 +392,11 @@ async fn run_update_cycle( } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; - let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await; + let cfg_v6 = retry_fetch( + "https://core.telegram.org/getProxyConfigV6", + upstream.clone(), + ) + .await; if let Some(cfg_v6) = cfg_v6 && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { @@ -456,7 +474,12 @@ async fn run_update_cycle( pool.reset_stun_state(); if cfg.general.proxy_secret_rotate_runtime { - match download_proxy_secret_with_max_len(cfg.general.proxy_secret_len_max).await { + match download_proxy_secret_with_max_len_via_upstream( + cfg.general.proxy_secret_len_max, + upstream, + ) + .await + { Ok(secret) => { let secret_hash = hash_secret(&secret); let stable_hits = state.secret.observe(secret_hash); diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs new file mode 100644 index 0000000..5be601e --- /dev/null +++ b/src/transport/middle_proxy/http_fetch.rs @@ -0,0 +1,183 @@ +use std::sync::Arc; +use std::time::Duration; + +use http_body_util::{BodyExt, Empty}; +use hyper::header::{CONNECTION, DATE, HOST, USER_AGENT}; +use hyper::{Method, Request}; +use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::TlsConnector; +use tracing::debug; + +use crate::error::{ProxyError, Result}; +use crate::network::dns_overrides::resolve_socket_addr; +use crate::transport::{UpstreamManager, UpstreamStream}; + +const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); + +pub(crate) struct HttpsGetResponse { + pub(crate) status: u16, + pub(crate) date_header: Option, + pub(crate) body: Vec, +} + +fn build_tls_client_config() -> Arc { + let mut root_store = rustls::RootCertStore::empty(); + 
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let provider = rustls::crypto::ring::default_provider(); + let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("HTTPS fetch rustls protocol versions must be valid") + .with_root_certificates(root_store) + .with_no_client_auth(); + Arc::new(config) +} + +fn extract_host_port_path(url: &str) -> Result<(String, u16, String)> { + let parsed = + url::Url::parse(url).map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; + if parsed.scheme() != "https" { + return Err(ProxyError::Proxy(format!( + "unsupported URL scheme '{}': only https is supported", + parsed.scheme() + ))); + } + + let host = parsed + .host_str() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no host: {url}")))? + .to_string(); + let port = parsed + .port_or_known_default() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no known port: {url}")))?; + + let mut path = parsed.path().to_string(); + if path.is_empty() { + path.push('/'); + } + if let Some(query) = parsed.query() { + path.push('?'); + path.push_str(query); + } + + Ok((host, port, path)) +} + +async fn resolve_target_addr(host: &str, port: u16) -> Result { + if let Some(addr) = resolve_socket_addr(host, port) { + return Ok(addr); + } + + let addrs: Vec = tokio::net::lookup_host((host, port)) + .await + .map_err(|e| ProxyError::Proxy(format!("DNS resolve failed for {host}:{port}: {e}")))? 
+ .collect(); + + if let Some(addr) = addrs.iter().copied().find(|addr| addr.is_ipv4()) { + return Ok(addr); + } + + addrs + .first() + .copied() + .ok_or_else(|| ProxyError::Proxy(format!("DNS returned no addresses for {host}:{port}"))) +} + +async fn connect_https_transport( + host: &str, + port: u16, + upstream: Option>, +) -> Result { + if let Some(manager) = upstream { + let target = resolve_target_addr(host, port).await?; + return timeout(HTTP_CONNECT_TIMEOUT, manager.connect(target, None, None)) + .await + .map_err(|_| ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")))? + .map_err(|e| { + ProxyError::Proxy(format!("upstream connect failed for {host}:{port}: {e}")) + }); + } + + if let Some(addr) = resolve_socket_addr(host, port) { + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + return Ok(UpstreamStream::Tcp(stream)); + } + + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect((host, port))) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + Ok(UpstreamStream::Tcp(stream)) +} + +pub(crate) async fn https_get( + url: &str, + upstream: Option>, +) -> Result { + let (host, port, path_and_query) = extract_host_port_path(url)?; + let stream = connect_https_transport(&host, port, upstream).await?; + + let server_name = ServerName::try_from(host.clone()) + .map_err(|_| ProxyError::Proxy(format!("invalid TLS server name: {host}")))?; + let connector = TlsConnector::from(build_tls_client_config()); + let tls_stream = timeout(HTTP_REQUEST_TIMEOUT, connector.connect(server_name, stream)) + .await + .map_err(|_| ProxyError::Proxy(format!("TLS handshake timeout for {host}:{port}")))? 
+ .map_err(|e| ProxyError::Proxy(format!("TLS handshake failed for {host}:{port}: {e}")))?; + + let (mut sender, connection) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .map_err(|e| ProxyError::Proxy(format!("HTTP handshake failed for {host}:{port}: {e}")))?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + debug!(error = %e, "HTTPS fetch connection task failed"); + } + }); + + let host_header = if port == 443 { + host.clone() + } else { + format!("{host}:{port}") + }; + + let request = Request::builder() + .method(Method::GET) + .uri(path_and_query) + .header(HOST, host_header) + .header(USER_AGENT, "telemt-middle-proxy/1") + .header(CONNECTION, "close") + .body(Empty::::new()) + .map_err(|e| ProxyError::Proxy(format!("build HTTP request failed for {url}: {e}")))?; + + let response = timeout(HTTP_REQUEST_TIMEOUT, sender.send_request(request)) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP request timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP request failed for {url}: {e}")))?; + + let status = response.status().as_u16(); + let date_header = response + .headers() + .get(DATE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()); + + let body = timeout(HTTP_REQUEST_TIMEOUT, response.into_body().collect()) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP body read timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP body read failed for {url}: {e}")))? 
+ .to_bytes() + .to_vec(); + + Ok(HttpsGetResponse { + status, + date_header, + body, + }) +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 5536869..6dfbee6 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -13,6 +13,7 @@ mod health_integration_tests; #[cfg(test)] #[path = "tests/health_regression_tests.rs"] mod health_regression_tests; +mod http_fetch; mod ping; mod pool; mod pool_config; @@ -44,7 +45,8 @@ use bytes::Bytes; #[allow(unused_imports)] pub use config_updater::{ - ProxyConfigData, fetch_proxy_config, fetch_proxy_config_with_raw, load_proxy_config_cache, + ProxyConfigData, fetch_proxy_config, fetch_proxy_config_via_upstream, + fetch_proxy_config_with_raw, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, me_config_updater, save_proxy_config_cache, }; pub use health::{me_drain_timeout_enforcer, me_health_monitor, me_zombie_writer_watchdog}; @@ -57,7 +59,8 @@ pub use pool::MePool; pub use pool_nat::{detect_public_ip, stun_probe}; pub use registry::ConnRegistry; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; -pub use secret::fetch_proxy_secret; +#[allow(unused_imports)] +pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 918ccd4..1ef59e1 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -293,9 +293,7 @@ impl MePool { WriterContour::Draining => "draining", }; - if !draining - && let Some(dc_idx) = dc - { + if !draining && let Some(dc_idx) = dc { *live_writers_by_dc_endpoint .entry((dc_idx, endpoint)) .or_insert(0) += 1; diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 504270a..a167773 
100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,9 +1,12 @@ use httpdate; +use std::sync::Arc; use std::time::SystemTime; use tracing::{debug, info, warn}; +use super::http_fetch::https_get; use super::selftest::record_timeskew_sample; use crate::error::{ProxyError, Result}; +use crate::transport::UpstreamManager; pub const PROXY_SECRET_MIN_LEN: usize = 32; @@ -33,11 +36,21 @@ pub(super) fn validate_proxy_secret_len(data_len: usize, max_len: usize) -> Resu } /// Fetch Telegram proxy-secret binary. +#[allow(dead_code)] pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result> { + fetch_proxy_secret_with_upstream(cache_path, max_len, None).await +} + +/// Fetch Telegram proxy-secret binary, optionally through upstream routing. +pub async fn fetch_proxy_secret_with_upstream( + cache_path: Option<&str>, + max_len: usize, + upstream: Option>, +) -> Result> { let cache = cache_path.unwrap_or("proxy-secret"); // 1) Try fresh download first. 
- match download_proxy_secret_with_max_len(max_len).await { + match download_proxy_secret_with_max_len_via_upstream(max_len, upstream).await { Ok(data) => { if let Err(e) = tokio::fs::write(cache, &data).await { warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); @@ -76,20 +89,25 @@ pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Res } } +#[allow(dead_code)] pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result> { - let resp = reqwest::get("https://core.telegram.org/getProxySecret") - .await - .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + download_proxy_secret_with_max_len_via_upstream(max_len, None).await +} - if !resp.status().is_success() { +pub async fn download_proxy_secret_with_max_len_via_upstream( + max_len: usize, + upstream: Option>, +) -> Result> { + let resp = https_get("https://core.telegram.org/getProxySecret", upstream).await?; + + if !(200..=299).contains(&resp.status) { return Err(ProxyError::Proxy(format!( "proxy-secret download HTTP {}", - resp.status() + resp.status ))); } - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -110,11 +128,7 @@ pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result>)> = { let pools = self.pools.read(); - pools.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect() + pools + .iter() + .map(|(addr, pool)| (*addr, Arc::clone(pool))) + .collect() }; for (addr, pool) in pools_snapshot {