Compare commits

..

30 Commits

Author SHA1 Message Date
Alexey 3bdbba8777
Merge 2d3c2807ab into a95678988a 2026-03-21 18:51:38 +00:00
Alexey 2d3c2807ab
Merge pull request #531 from DavidOsipov/flow
Small brittle test fix
2026-03-21 21:51:33 +03:00
David Osipov 50ae16ddf7
Add interval_gap_usize function and enhance integration test assertions for class separability 2026-03-21 22:49:39 +04:00
Alexey de5c26b7d7
Merge branch 'main' into flow 2026-03-21 21:46:45 +03:00
Alexey a95678988a
Merge pull request #530 from telemt/workflow
Update release.yml
2026-03-21 21:45:23 +03:00
Alexey b17482ede3
Update release.yml 2026-03-21 21:45:01 +03:00
Alexey a059de9191
Merge pull request #529 from DavidOsipov/flow
Усиление обхода DPI (Shape/Timing Hardening), защита от тайминг-атак и масштабное покрытие тестами
2026-03-21 21:31:05 +03:00
David Osipov e7e763888b
Implement aggressive shape hardening mode and related tests 2026-03-21 22:25:29 +04:00
David Osipov c0a3e43aa8
Add comprehensive security tests for proxy functionality
- Introduced client TLS record wrapping tests to ensure correct handling of empty and oversized payloads.
- Added integration tests for middle relay to validate quota saturation behavior under concurrent pressure.
- Implemented high-risk security tests covering various payload scenarios, including alignment checks and boundary conditions.
- Developed length cast hardening tests to verify proper handling of wire lengths and overflow conditions.
- Created quota overflow lock tests to ensure stable behavior under saturation and reclaim scenarios.
- Refactored existing middle relay security tests for improved clarity and consistency in lock handling.
2026-03-21 20:54:13 +04:00
David Osipov 4c32370b25
Refactor proxy and transport modules for improved safety and performance
- Enhanced linting rules in `src/proxy/mod.rs` to enforce stricter code quality checks in production.
- Updated hash functions in `src/proxy/middle_relay.rs` for better efficiency.
- Added new security tests in `src/proxy/tests/middle_relay_stub_completion_security_tests.rs` to validate desynchronization behavior.
- Removed ignored test stubs in `src/proxy/tests/middle_relay_security_tests.rs` to clean up the test suite.
- Improved error handling and code readability in various transport modules, including `src/transport/middle_proxy/config_updater.rs` and `src/transport/middle_proxy/pool.rs`.
- Introduced new padding functions in `src/stream/frame_stream_padding_security_tests.rs` to ensure consistent behavior across different implementations.
- Adjusted TLS stream validation in `src/stream/tls_stream.rs` for better boundary checking.
- General code cleanup and dead code elimination across multiple files to enhance maintainability.
2026-03-21 20:05:07 +04:00
Alexey a6c298b633
Merge branch 'main' into flow 2026-03-21 16:54:47 +03:00
Alexey e7a1d26e6e
Merge pull request #526 from telemt/workflow
Update release.yml
2026-03-21 16:48:53 +03:00
Alexey b91c6cb339
Update release.yml 2026-03-21 16:48:42 +03:00
Alexey e676633dcd
Merge branch 'main' into flow 2026-03-21 16:32:24 +03:00
Alexey c4e7f54cbe
Merge pull request #524 from telemt/workflow
Update release.yml
2026-03-21 16:31:15 +03:00
Alexey f85205d48d
Update release.yml 2026-03-21 16:31:05 +03:00
Alexey d767ec02ee
Update release.yml 2026-03-21 16:24:06 +03:00
Alexey 51835c33f2
Merge branch 'main' into flow 2026-03-21 16:19:02 +03:00
Alexey 88a4c652b6
Merge pull request #523 from telemt/workflow
Update release.yml
2026-03-21 16:18:48 +03:00
Alexey ea2d964502
Update release.yml 2026-03-21 16:18:24 +03:00
Alexey bd7218c39c
Merge branch 'main' into flow 2026-03-21 16:06:03 +03:00
Alexey 3055637571
Merge pull request #522 from telemt/workflow
Update release.yml
2026-03-21 16:04:56 +03:00
Alexey 19b84b9d73
Update release.yml 2026-03-21 16:03:54 +03:00
Alexey 165a1ede57
Merge branch 'main' into flow 2026-03-21 15:58:53 +03:00
Alexey 6ead8b1922
Merge pull request #521 from telemt/workflow
Update release.yml
2026-03-21 15:58:36 +03:00
Alexey 63aa1038c0
Update release.yml 2026-03-21 15:58:25 +03:00
Alexey 4473826303
Update crypto_bench.rs 2026-03-21 15:48:28 +03:00
Alexey d7bbb376c9
Format 2026-03-21 15:45:29 +03:00
Alexey 7a8f946029
Update Cargo.lock 2026-03-21 15:35:03 +03:00
Alexey f2e6dc1774
Update Cargo.toml 2026-03-21 15:27:21 +03:00
173 changed files with 8934 additions and 4201 deletions

View File

@ -21,16 +21,13 @@ env:
jobs: jobs:
prepare: prepare:
name: Prepare metadata
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
version: ${{ steps.meta.outputs.version }} version: ${{ steps.meta.outputs.version }}
prerelease: ${{ steps.meta.outputs.prerelease }} prerelease: ${{ steps.meta.outputs.prerelease }}
release_enabled: ${{ steps.meta.outputs.release_enabled }} release_enabled: ${{ steps.meta.outputs.release_enabled }}
steps: steps:
- name: Derive version - id: meta
id: meta
shell: bash
run: | run: |
set -euo pipefail set -euo pipefail
@ -53,67 +50,38 @@ jobs:
echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT" echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT"
checks: checks:
name: Checks
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: debian:trixie image: debian:trixie
steps: steps:
- name: Install system dependencies - run: |
shell: bash
run: |
set -euo pipefail
apt-get update apt-get update
apt-get install -y --no-install-recommends \ apt-get install -y build-essential clang llvm pkg-config curl git
ca-certificates \
curl \
git \
build-essential \
pkg-config \
clang \
llvm \
python3 \
python3-pip
update-ca-certificates
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable - uses: dtolnay/rust-toolchain@stable
with: with:
components: rustfmt, clippy components: rustfmt, clippy
- name: Cache cargo - uses: actions/cache@v4
uses: actions/cache@v4
with: with:
path: | path: |
/github/home/.cargo/registry /github/home/.cargo/registry
/github/home/.cargo/git /github/home/.cargo/git
target target
key: checks-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }} key: checks-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
checks-${{ runner.os }}-
- name: Cargo fetch - run: cargo fetch --locked
shell: bash - run: cargo fmt --all -- --check
run: cargo fetch --locked - run: cargo clippy
- run: cargo test
- name: Format
shell: bash
run: cargo fmt --all -- --check
- name: Clippy
shell: bash
run: cargo clippy --workspace --all-targets --locked -- -D warnings
- name: Tests
shell: bash
run: cargo test --workspace --all-targets --locked
build-binaries: build-binaries:
name: Build ${{ matrix.asset_name }}
needs: [prepare, checks] needs: [prepare, checks]
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: debian:trixie image: debian:trixie
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -132,156 +100,80 @@ jobs:
asset_name: telemt-aarch64-linux-musl asset_name: telemt-aarch64-linux-musl
steps: steps:
- name: Install system dependencies - run: |
shell: bash
run: |
set -euo pipefail
apt-get update apt-get update
apt-get install -y --no-install-recommends \ apt-get install -y clang llvm pkg-config curl git python3 python3-pip file tar xz-utils
ca-certificates \
curl \
git \
build-essential \
pkg-config \
clang \
llvm \
file \
tar \
xz-utils \
python3 \
python3-pip
update-ca-certificates
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable - uses: dtolnay/rust-toolchain@stable
with: with:
targets: ${{ matrix.rust_target }} targets: ${{ matrix.rust_target }}
- name: Cache cargo - uses: actions/cache@v4
uses: actions/cache@v4
with: with:
path: | path: |
/github/home/.cargo/registry /github/home/.cargo/registry
/github/home/.cargo/git /github/home/.cargo/git
target target
key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }} key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
build-${{ matrix.zig_target }}-
- name: Install cargo-zigbuild + Zig - run: |
shell: bash
run: |
set -euo pipefail
python3 -m pip install --user --break-system-packages cargo-zigbuild python3 -m pip install --user --break-system-packages cargo-zigbuild
echo "/github/home/.local/bin" >> "$GITHUB_PATH" echo "/github/home/.local/bin" >> "$GITHUB_PATH"
- name: Cargo fetch - run: cargo fetch --locked
shell: bash
run: cargo fetch --locked
- name: Build release - run: |
shell: bash
env:
CARGO_PROFILE_RELEASE_LTO: "fat"
CARGO_PROFILE_RELEASE_CODEGEN_UNITS: "1"
CARGO_PROFILE_RELEASE_PANIC: "abort"
run: |
set -euo pipefail
cargo zigbuild --release --locked --target "${{ matrix.zig_target }}" cargo zigbuild --release --locked --target "${{ matrix.zig_target }}"
- name: Strip binary - run: |
shell: bash BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}"
run: | llvm-strip "$BIN" || true
set -euo pipefail
llvm-strip "target/${{ matrix.zig_target }}/release/${BINARY_NAME}" || true
- name: Inspect binary - run: |
shell: bash BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}"
run: | OUT="$RUNNER_TEMP/${{ matrix.asset_name }}"
set -euo pipefail mkdir -p "$OUT"
file "target/${{ matrix.zig_target }}/release/${BINARY_NAME}" install -m755 "$BIN" "$OUT/${BINARY_NAME}"
- name: Package tar -C "$RUNNER_TEMP" -czf "${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}"
shell: bash sha256sum "${{ matrix.asset_name }}.tar.gz" > "${{ matrix.asset_name }}.sha256"
run: |
set -euo pipefail
OUTDIR="$RUNNER_TEMP/pkg/${{ matrix.asset_name }}"
mkdir -p "$OUTDIR"
install -m 0755 "target/${{ matrix.zig_target }}/release/${BINARY_NAME}" "$OUTDIR/${BINARY_NAME}"
if [[ -f LICENSE ]]; then cp LICENSE "$OUTDIR/"; fi
if [[ -f README.md ]]; then cp README.md "$OUTDIR/"; fi
cat > "$OUTDIR/BUILD-INFO.txt" <<EOF
project=${GITHUB_REPOSITORY}
version=${{ needs.prepare.outputs.version }}
git_ref=${GITHUB_REF}
git_sha=${GITHUB_SHA}
rust_target=${{ matrix.rust_target }}
zig_target=${{ matrix.zig_target }}
built_at=$(date -u +%Y-%m-%dT%H:%M:%SZ)
EOF
mkdir -p dist
tar -C "$RUNNER_TEMP/pkg" -czf "dist/${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}"
sha256sum "dist/${{ matrix.asset_name }}.tar.gz" > "dist/${{ matrix.asset_name }}.sha256"
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.asset_name }} name: ${{ matrix.asset_name }}
path: | path: |
dist/${{ matrix.asset_name }}.tar.gz ${{ matrix.asset_name }}.tar.gz
dist/${{ matrix.asset_name }}.sha256 ${{ matrix.asset_name }}.sha256
if-no-files-found: error
retention-days: 14
attest-binaries:
name: Attest binary archives
needs: build-binaries
runs-on: ubuntu-latest
permissions:
contents: read
attestations: write
id-token: write
steps:
- uses: actions/download-artifact@v4
with:
path: dist
- name: Flatten artifacts
shell: bash
run: |
set -euo pipefail
mkdir -p upload
find dist -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \;
ls -lah upload
- name: Attest release archives
uses: actions/attest-build-provenance@v3
with:
subject-path: 'upload/*.tar.gz'
docker-image: docker-image:
name: Build and push GHCR image name: Docker ${{ matrix.platform }}
needs: [prepare, checks] needs: [prepare, build-binaries]
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: read strategy:
packages: write matrix:
include:
- platform: linux/amd64
artifact: telemt-x86_64-linux-gnu
- platform: linux/arm64
artifact: telemt-aarch64-linux-gnu
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up QEMU - uses: actions/download-artifact@v4
uses: docker/setup-qemu-action@v3 with:
name: ${{ matrix.artifact }}
path: dist
- name: Set up Buildx - run: |
uses: docker/setup-buildx-action@v3 mkdir docker-build
tar -xzf dist/*.tar.gz -C docker-build --strip-components=1
- name: Log in to GHCR - uses: docker/setup-buildx-action@v3
- name: Login
if: ${{ needs.prepare.outputs.release_enabled == 'true' }} if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
@ -289,43 +181,20 @@ jobs:
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Docker metadata - uses: docker/build-push-action@v6
id: meta
uses: docker/metadata-action@v5
with: with:
images: ghcr.io/${{ github.repository }} context: ./docker-build
tags: | platforms: ${{ matrix.platform }}
type=raw,value=${{ needs.prepare.outputs.version }}
type=raw,value=latest,enable=${{ needs.prepare.outputs.prerelease != 'true' && needs.prepare.outputs.release_enabled == 'true' }}
labels: |
org.opencontainers.image.title=telemt
org.opencontainers.image.description=telemt
org.opencontainers.image.source=https://github.com/${{ github.repository }}
org.opencontainers.image.version=${{ needs.prepare.outputs.version }}
org.opencontainers.image.revision=${{ github.sha }}
- name: Build and push
id: build
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: ${{ needs.prepare.outputs.release_enabled == 'true' }} push: ${{ needs.prepare.outputs.release_enabled == 'true' }}
tags: ${{ steps.meta.outputs.tags }} tags: ghcr.io/${{ github.repository }}:${{ needs.prepare.outputs.version }}
labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha,scope=telemt-${{ matrix.platform }}
cache-from: type=gha cache-to: type=gha,mode=max,scope=telemt-${{ matrix.platform }}
cache-to: type=gha,mode=max provenance: false
provenance: mode=max sbom: false
sbom: true
build-args: |
TELEMT_VERSION=${{ needs.prepare.outputs.version }}
VCS_REF=${{ github.sha }}
release: release:
name: Create GitHub Release
if: ${{ needs.prepare.outputs.release_enabled == 'true' }} if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
needs: [prepare, build-binaries, attest-binaries, docker-image] needs: [prepare, build-binaries]
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write
@ -334,19 +203,14 @@ jobs:
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
path: release-artifacts path: release-artifacts
pattern: telemt-*
- name: Flatten artifacts - run: |
shell: bash mkdir upload
run: |
set -euo pipefail
mkdir -p upload
find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \; find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \;
ls -lah upload
- name: Create release - uses: softprops/action-gh-release@v2
uses: softprops/action-gh-release@v2
with: with:
files: upload/* files: upload/*
generate_release_notes: true generate_release_notes: true
draft: false
prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }} prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }}

6
Cargo.lock generated
View File

@ -90,9 +90,9 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]] [[package]]
name = "arc-swap" name = "arc-swap"
version = "1.8.2" version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5" checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
dependencies = [ dependencies = [
"rustversion", "rustversion",
] ]
@ -2771,7 +2771,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.4.0" version = "3.3.29"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "telemt" name = "telemt"
version = "3.4.0" version = "3.3.29"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
@ -27,7 +27,7 @@ static_assertions = "1.1"
# Network # Network
socket2 = { version = "0.6", features = ["all"] } socket2 = { version = "0.6", features = ["all"] }
nix = { version = "0.31", default-features = false, features = ["net"] } nix = { version = "0.31", default-features = false, features = ["net", "fs"] }
shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] }
# Serialization # Serialization

View File

@ -1,5 +1,5 @@
// Cryptobench // Cryptobench
use criterion::{black_box, criterion_group, Criterion}; use criterion::{Criterion, black_box, criterion_group};
fn bench_aes_ctr(c: &mut Criterion) { fn bench_aes_ctr(c: &mut Criterion) {
c.bench_function("aes_ctr_encrypt_64kb", |b| { c.bench_function("aes_ctr_encrypt_64kb", |b| {

View File

@ -261,6 +261,7 @@ This document lists all configuration keys accepted by `config.toml`.
| alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. | | alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. |
| mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). | | mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). |
| mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. | | mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. |
| mask_shape_hardening_aggressive_mode | `bool` | `false` | Requires `mask_shape_hardening = true`. | Opt-in aggressive shaping profile: allows shaping on backend-silent non-EOF paths and switches above-cap blur to strictly positive random tail. |
| mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. | | mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. |
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
@ -284,6 +285,27 @@ When `mask_shape_hardening = true`, Telemt pads the **client->mask** stream tail
This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder. This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder.
What each parameter changes in practice:
- `mask_shape_hardening`
Enables or disables this entire length-shaping stage on the fallback path.
When `false`, backend-observed length stays close to the real forwarded probe length.
When `true`, clean relay shutdown can append random padding bytes to move the total into a bucket.
- `mask_shape_bucket_floor_bytes`
Sets the first bucket boundary used for small probes.
Example: with floor `512`, a malformed probe that would otherwise forward `37` bytes can be expanded to `512` bytes on clean EOF.
Larger floor values hide very small probes better, but increase egress cost.
- `mask_shape_bucket_cap_bytes`
Sets the largest bucket Telemt will pad up to with bucket logic.
Example: with cap `4096`, a forwarded total of `1800` bytes may be padded to `2048` or `4096` depending on the bucket ladder, but a total already above `4096` will not be bucket-padded further.
Larger cap values increase the range over which size classes are collapsed, but also increase worst-case overhead.
- Clean EOF matters in conservative mode
In the default profile, shape padding is intentionally conservative: it is applied on clean relay shutdown, not on every timeout/drip path.
This avoids introducing new timeout-tail artifacts that some backends or tests interpret as a separate fingerprint.
Practical trade-offs: Practical trade-offs:
- Better anti-fingerprinting on size/shape channel. - Better anti-fingerprinting on size/shape channel.
@ -296,14 +318,56 @@ Recommended starting profile:
- `mask_shape_bucket_floor_bytes = 512` - `mask_shape_bucket_floor_bytes = 512`
- `mask_shape_bucket_cap_bytes = 4096` - `mask_shape_bucket_cap_bytes = 4096`
### Aggressive mode notes (`[censorship]`)
`mask_shape_hardening_aggressive_mode` is an opt-in profile for higher anti-classifier pressure.
- Default is `false` to preserve conservative timeout/no-tail behavior.
- Requires `mask_shape_hardening = true`.
- When enabled, backend-silent non-EOF masking paths may be shaped.
- When enabled together with above-cap blur, the random extra tail uses `[1, max]` instead of `[0, max]`.
What changes when aggressive mode is enabled:
- Backend-silent timeout paths can be shaped
In default mode, a client that keeps the socket half-open and times out will usually not receive shape padding on that path.
In aggressive mode, Telemt may still shape that backend-silent session if no backend bytes were returned.
This is specifically aimed at active probes that try to avoid EOF in order to preserve an exact backend-observed length.
- Above-cap blur always adds at least one byte
In default mode, above-cap blur may choose `0`, so some oversized probes still land on their exact base forwarded length.
In aggressive mode, that exact-base sample is removed by construction.
- Tradeoff
Aggressive mode improves resistance to active length classifiers, but it is more opinionated and less conservative.
If your deployment prioritizes strict compatibility with timeout/no-tail semantics, leave it disabled.
If your threat model includes repeated active probing by a censor, this mode is the stronger profile.
Use this mode only when your threat model prioritizes classifier resistance over strict compatibility with conservative masking semantics.
### Above-cap blur notes (`[censorship]`) ### Above-cap blur notes (`[censorship]`)
`mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`. `mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`.
- A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended. - A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended in default mode.
- In aggressive mode, the random tail becomes strictly positive: `[1, mask_shape_above_cap_blur_max_bytes]`.
- This reduces exact-size leakage above cap at bounded overhead. - This reduces exact-size leakage above cap at bounded overhead.
- Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth. - Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth.
Operational meaning:
- Without above-cap blur
A probe that forwards `5005` bytes will still look like `5005` bytes to the backend if it is already above cap.
- With above-cap blur enabled
That same probe may look like any value in a bounded window above its base length.
Example with `mask_shape_above_cap_blur_max_bytes = 64`:
backend-observed size becomes `5005..5069` in default mode, or `5006..5069` in aggressive mode.
- Choosing `mask_shape_above_cap_blur_max_bytes`
Small values reduce cost but preserve more separability between far-apart oversized classes.
Larger values blur oversized classes more aggressively, but add more egress overhead and more output variance.
### Timing normalization envelope notes (`[censorship]`) ### Timing normalization envelope notes (`[censorship]`)
`mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope. `mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope.

View File

@ -24,10 +24,7 @@ pub(super) fn success_response<T: Serialize>(
.unwrap() .unwrap()
} }
pub(super) fn error_response( pub(super) fn error_response(request_id: u64, failure: ApiFailure) -> hyper::Response<Full<Bytes>> {
request_id: u64,
failure: ApiFailure,
) -> hyper::Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
ok: false, ok: false,
error: ErrorBody { error: ErrorBody {

View File

@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
@ -19,8 +21,8 @@ use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::RouteRuntimeController; use crate::proxy::route_mode::RouteRuntimeController;
use crate::startup::StartupTracker; use crate::startup::StartupTracker;
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
mod config_store; mod config_store;
mod events; mod events;
@ -36,8 +38,8 @@ mod runtime_zero;
mod users; mod users;
use config_store::{current_revision, parse_if_match}; use config_store::{current_revision, parse_if_match};
use http_utils::{error_response, read_json, read_optional_json, success_response};
use events::ApiEventStore; use events::ApiEventStore;
use http_utils::{error_response, read_json, read_optional_json, success_response};
use model::{ use model::{
ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData,
}; };
@ -55,11 +57,11 @@ use runtime_stats::{
MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data, MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data,
build_upstreams_data, build_zero_all_data, build_upstreams_data, build_zero_all_data,
}; };
use runtime_watch::spawn_runtime_watchers;
use runtime_zero::{ use runtime_zero::{
build_limits_effective_data, build_runtime_gates_data, build_security_posture_data, build_limits_effective_data, build_runtime_gates_data, build_security_posture_data,
build_system_info_data, build_system_info_data,
}; };
use runtime_watch::spawn_runtime_watchers;
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
pub(super) struct ApiRuntimeState { pub(super) struct ApiRuntimeState {
@ -208,15 +210,15 @@ async fn handle(
)); ));
} }
if !api_cfg.whitelist.is_empty() if !api_cfg.whitelist.is_empty() && !api_cfg.whitelist.iter().any(|net| net.contains(peer.ip()))
&& !api_cfg
.whitelist
.iter()
.any(|net| net.contains(peer.ip()))
{ {
return Ok(error_response( return Ok(error_response(
request_id, request_id,
ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"), ApiFailure::new(
StatusCode::FORBIDDEN,
"forbidden",
"Source IP is not allowed",
),
)); ));
} }
@ -347,7 +349,8 @@ async fn handle(
} }
("GET", "/v1/runtime/connections/summary") => { ("GET", "/v1/runtime/connections/summary") => {
let revision = current_revision(&shared.config_path).await?; let revision = current_revision(&shared.config_path).await?;
let data = build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await; let data =
build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await;
Ok(success_response(StatusCode::OK, data, revision)) Ok(success_response(StatusCode::OK, data, revision))
} }
("GET", "/v1/runtime/events/recent") => { ("GET", "/v1/runtime/events/recent") => {
@ -389,13 +392,16 @@ async fn handle(
let (data, revision) = match result { let (data, revision) = match result {
Ok(ok) => ok, Ok(ok) => ok,
Err(error) => { Err(error) => {
shared.runtime_events.record("api.user.create.failed", error.code); shared
.runtime_events
.record("api.user.create.failed", error.code);
return Err(error); return Err(error);
} }
}; };
shared shared.runtime_events.record(
.runtime_events "api.user.create.ok",
.record("api.user.create.ok", format!("username={}", data.user.username)); format!("username={}", data.user.username),
);
Ok(success_response(StatusCode::CREATED, data, revision)) Ok(success_response(StatusCode::CREATED, data, revision))
} }
_ => { _ => {
@ -414,7 +420,8 @@ async fn handle(
detected_ip_v6, detected_ip_v6,
) )
.await; .await;
if let Some(user_info) = users.into_iter().find(|entry| entry.username == user) if let Some(user_info) =
users.into_iter().find(|entry| entry.username == user)
{ {
return Ok(success_response(StatusCode::OK, user_info, revision)); return Ok(success_response(StatusCode::OK, user_info, revision));
} }
@ -435,7 +442,8 @@ async fn handle(
)); ));
} }
let expected_revision = parse_if_match(req.headers()); let expected_revision = parse_if_match(req.headers());
let body = read_json::<PatchUserRequest>(req.into_body(), body_limit).await?; let body =
read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
let result = patch_user(user, body, expected_revision, &shared).await; let result = patch_user(user, body, expected_revision, &shared).await;
let (data, revision) = match result { let (data, revision) = match result {
Ok(ok) => ok, Ok(ok) => ok,
@ -475,10 +483,9 @@ async fn handle(
return Err(error); return Err(error);
} }
}; };
shared.runtime_events.record( shared
"api.user.delete.ok", .runtime_events
format!("username={}", deleted_user), .record("api.user.delete.ok", format!("username={}", deleted_user));
);
return Ok(success_response(StatusCode::OK, deleted_user, revision)); return Ok(success_response(StatusCode::OK, deleted_user, revision));
} }
if method == Method::POST if method == Method::POST

View File

@ -167,11 +167,7 @@ async fn current_me_pool_stage_progress(shared: &ApiShared) -> Option<f64> {
let pool = shared.me_pool.read().await.clone()?; let pool = shared.me_pool.read().await.clone()?;
let status = pool.api_status_snapshot().await; let status = pool.api_status_snapshot().await;
let configured_dc_groups = status.configured_dc_groups; let configured_dc_groups = status.configured_dc_groups;
let covered_dc_groups = status let covered_dc_groups = status.dcs.iter().filter(|dc| dc.alive_writers > 0).count();
.dcs
.iter()
.filter(|dc| dc.alive_writers > 0)
.count();
let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups); let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups);
let writer_coverage = ratio_01(status.alive_writers, status.required_writers); let writer_coverage = ratio_01(status.alive_writers, status.required_writers);

View File

@ -2,8 +2,8 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::config::ApiConfig; use crate::config::ApiConfig;
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::upstream::IpPreference;
use crate::transport::UpstreamRouteKind; use crate::transport::UpstreamRouteKind;
use crate::transport::upstream::IpPreference;
use super::ApiShared; use super::ApiShared;
use super::model::{ use super::model::{

View File

@ -128,7 +128,8 @@ pub(super) fn build_system_info_data(
.runtime_state .runtime_state
.last_config_reload_epoch_secs .last_config_reload_epoch_secs
.load(Ordering::Relaxed); .load(Ordering::Relaxed);
let last_config_reload_epoch_secs = (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs); let last_config_reload_epoch_secs =
(last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs);
let git_commit = option_env!("TELEMT_GIT_COMMIT") let git_commit = option_env!("TELEMT_GIT_COMMIT")
.or(option_env!("VERGEN_GIT_SHA")) .or(option_env!("VERGEN_GIT_SHA"))
@ -153,7 +154,10 @@ pub(super) fn build_system_info_data(
uptime_seconds: shared.stats.uptime_secs(), uptime_seconds: shared.stats.uptime_secs(),
config_path: shared.config_path.display().to_string(), config_path: shared.config_path.display().to_string(),
config_hash: revision.to_string(), config_hash: revision.to_string(),
config_reload_count: shared.runtime_state.config_reload_count.load(Ordering::Relaxed), config_reload_count: shared
.runtime_state
.config_reload_count
.load(Ordering::Relaxed),
last_config_reload_epoch_secs, last_config_reload_epoch_secs,
} }
} }
@ -233,9 +237,7 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD
adaptive_floor_writers_per_core_total: cfg adaptive_floor_writers_per_core_total: cfg
.general .general
.me_adaptive_floor_writers_per_core_total, .me_adaptive_floor_writers_per_core_total,
adaptive_floor_cpu_cores_override: cfg adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override,
.general
.me_adaptive_floor_cpu_cores_override,
adaptive_floor_max_extra_writers_single_per_core: cfg adaptive_floor_max_extra_writers_single_per_core: cfg
.general .general
.me_adaptive_floor_max_extra_writers_single_per_core, .me_adaptive_floor_max_extra_writers_single_per_core,

View File

@ -46,7 +46,9 @@ pub(super) async fn create_user(
None => random_user_secret(), None => random_user_secret(),
}; };
if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { if let Some(ad_tag) = body.user_ad_tag.as_ref()
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters", "user_ad_tag must be exactly 32 hex characters",
)); ));
@ -65,12 +67,18 @@ pub(super) async fn create_user(
)); ));
} }
cfg.access.users.insert(body.username.clone(), secret.clone()); cfg.access
.users
.insert(body.username.clone(), secret.clone());
if let Some(ad_tag) = body.user_ad_tag { if let Some(ad_tag) = body.user_ad_tag {
cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag); cfg.access
.user_ad_tags
.insert(body.username.clone(), ad_tag);
} }
if let Some(limit) = body.max_tcp_conns { if let Some(limit) = body.max_tcp_conns {
cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit); cfg.access
.user_max_tcp_conns
.insert(body.username.clone(), limit);
} }
if let Some(expiration) = expiration { if let Some(expiration) = expiration {
cfg.access cfg.access
@ -78,7 +86,9 @@ pub(super) async fn create_user(
.insert(body.username.clone(), expiration); .insert(body.username.clone(), expiration);
} }
if let Some(quota) = body.data_quota_bytes { if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(body.username.clone(), quota); cfg.access
.user_data_quota
.insert(body.username.clone(), quota);
} }
let updated_limit = body.max_unique_ips; let updated_limit = body.max_unique_ips;
@ -108,11 +118,15 @@ pub(super) async fn create_user(
touched_sections.push(AccessSection::UserMaxUniqueIps); touched_sections.push(AccessSection::UserMaxUniqueIps);
} }
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
if let Some(limit) = updated_limit { if let Some(limit) = updated_limit {
shared.ip_tracker.set_user_limit(&body.username, limit).await; shared
.ip_tracker
.set_user_limit(&body.username, limit)
.await;
} }
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@ -140,12 +154,7 @@ pub(super) async fn create_user(
recent_unique_ips: 0, recent_unique_ips: 0,
recent_unique_ips_list: Vec::new(), recent_unique_ips_list: Vec::new(),
total_octets: 0, total_octets: 0,
links: build_user_links( links: build_user_links(&cfg, &secret, detected_ip_v4, detected_ip_v6),
&cfg,
&secret,
detected_ip_v4,
detected_ip_v6,
),
}); });
Ok((CreateUserResponse { user, secret }, revision)) Ok((CreateUserResponse { user, secret }, revision))
@ -157,12 +166,16 @@ pub(super) async fn patch_user(
expected_revision: Option<String>, expected_revision: Option<String>,
shared: &ApiShared, shared: &ApiShared,
) -> Result<(UserInfo, String), ApiFailure> { ) -> Result<(UserInfo, String), ApiFailure> {
if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) { if let Some(secret) = body.secret.as_ref()
&& !is_valid_user_secret(secret)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"secret must be exactly 32 hex characters", "secret must be exactly 32 hex characters",
)); ));
} }
if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) { if let Some(ad_tag) = body.user_ad_tag.as_ref()
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters", "user_ad_tag must be exactly 32 hex characters",
)); ));
@ -187,10 +200,14 @@ pub(super) async fn patch_user(
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
} }
if let Some(limit) = body.max_tcp_conns { if let Some(limit) = body.max_tcp_conns {
cfg.access.user_max_tcp_conns.insert(user.to_string(), limit); cfg.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
} }
if let Some(expiration) = expiration { if let Some(expiration) = expiration {
cfg.access.user_expirations.insert(user.to_string(), expiration); cfg.access
.user_expirations
.insert(user.to_string(), expiration);
} }
if let Some(quota) = body.data_quota_bytes { if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(user.to_string(), quota); cfg.access.user_data_quota.insert(user.to_string(), quota);
@ -198,7 +215,9 @@ pub(super) async fn patch_user(
let mut updated_limit = None; let mut updated_limit = None;
if let Some(limit) = body.max_unique_ips { if let Some(limit) = body.max_unique_ips {
cfg.access.user_max_unique_ips.insert(user.to_string(), limit); cfg.access
.user_max_unique_ips
.insert(user.to_string(), limit);
updated_limit = Some(limit); updated_limit = Some(limit);
} }
@ -263,7 +282,8 @@ pub(super) async fn rotate_secret(
AccessSection::UserDataQuota, AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps, AccessSection::UserMaxUniqueIps,
]; ];
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@ -330,7 +350,8 @@ pub(super) async fn delete_user(
AccessSection::UserDataQuota, AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps, AccessSection::UserMaxUniqueIps,
]; ];
let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?; let revision =
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
shared.ip_tracker.remove_user_limit(user).await; shared.ip_tracker.remove_user_limit(user).await;
shared.ip_tracker.clear_user_ips(user).await; shared.ip_tracker.clear_user_ips(user).await;
@ -365,12 +386,7 @@ pub(super) async fn users_from_config(
.users .users
.get(&username) .get(&username)
.map(|secret| { .map(|secret| {
build_user_links( build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6)
cfg,
secret,
startup_detected_ip_v4,
startup_detected_ip_v6,
)
}) })
.unwrap_or(UserLinks { .unwrap_or(UserLinks {
classic: Vec::new(), classic: Vec::new(),
@ -392,10 +408,8 @@ pub(super) async fn users_from_config(
.get(&username) .get(&username)
.copied() .copied()
.filter(|limit| *limit > 0) .filter(|limit| *limit > 0)
.or( .or((cfg.access.user_max_unique_ips_global_each > 0)
(cfg.access.user_max_unique_ips_global_each > 0) .then_some(cfg.access.user_max_unique_ips_global_each)),
.then_some(cfg.access.user_max_unique_ips_global_each),
),
current_connections: stats.get_user_curr_connects(&username), current_connections: stats.get_user_curr_connects(&username),
active_unique_ips: active_ip_list.len(), active_unique_ips: active_ip_list.len(),
active_unique_ips_list: active_ip_list, active_unique_ips_list: active_ip_list,
@ -481,12 +495,12 @@ fn resolve_link_hosts(
push_unique_host(&mut hosts, host); push_unique_host(&mut hosts, host);
continue; continue;
} }
if let Some(ip) = listener.announce_ip { if let Some(ip) = listener.announce_ip
if !ip.is_unspecified() { && !ip.is_unspecified()
{
push_unique_host(&mut hosts, &ip.to_string()); push_unique_host(&mut hosts, &ip.to_string());
continue; continue;
} }
}
if listener.ip.is_unspecified() { if listener.ip.is_unspecified() {
let detected_ip = if listener.ip.is_ipv4() { let detected_ip = if listener.ip.is_ipv4() {
startup_detected_ip_v4 startup_detected_ip_v4

View File

@ -1,9 +1,9 @@
//! CLI commands: --init (fire-and-forget setup) //! CLI commands: --init (fire-and-forget setup)
use rand::RngExt;
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use rand::RngExt;
/// Options for the init command /// Options for the init command
pub struct InitOptions { pub struct InitOptions {
@ -114,8 +114,8 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[+] Config written to {}", config_path.display()); eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit // 4. Write systemd unit
let exe_path = std::env::current_exe() let exe_path =
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service"); let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path); let unit_content = generate_systemd_unit(&exe_path, &config_path);
@ -183,7 +183,7 @@ fn generate_secret() -> String {
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String { fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!( format!(
r#"# Telemt MTProxy — auto-generated config r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate # Re-run `telemt --init` to regenerate
show_link = ["{username}"] show_link = ["{username}"]
@ -266,7 +266,7 @@ weight = 10
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String { fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!( format!(
r#"[Unit] r#"[Unit]
Description=Telemt MTProxy Description=Telemt MTProxy
Documentation=https://github.com/telemt/telemt Documentation=https://github.com/telemt/telemt
After=network-online.target After=network-online.target
@ -312,8 +312,10 @@ fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
println!("=== Proxy Links ==="); println!("=== Proxy Links ===");
println!("[{}]", username); println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", println!(
port, secret, domain_hex); " EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex
);
println!(); println!();
println!("Replace YOUR_SERVER_IP with your server's public IP."); println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup."); println!("The proxy will auto-detect and display the correct link on startup.");

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap;
// Helper defaults kept private to the config module. // Helper defaults kept private to the config module.
const DEFAULT_NETWORK_IPV6: Option<bool> = Some(false); const DEFAULT_NETWORK_IPV6: Option<bool> = Some(false);
@ -143,10 +143,7 @@ pub(crate) fn default_weight() -> u16 {
} }
pub(crate) fn default_metrics_whitelist() -> Vec<IpNetwork> { pub(crate) fn default_metrics_whitelist() -> Vec<IpNetwork> {
vec![ vec!["127.0.0.1/32".parse().unwrap(), "::1/128".parse().unwrap()]
"127.0.0.1/32".parse().unwrap(),
"::1/128".parse().unwrap(),
]
} }
pub(crate) fn default_api_listen() -> String { pub(crate) fn default_api_listen() -> String {
@ -169,10 +166,18 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 {
1000 1000
} }
pub(crate) fn default_api_runtime_edge_enabled() -> bool { false } pub(crate) fn default_api_runtime_edge_enabled() -> bool {
pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { 1000 } false
pub(crate) fn default_api_runtime_edge_top_n() -> usize { 10 } }
pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { 256 } pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 {
1000
}
pub(crate) fn default_api_runtime_edge_top_n() -> usize {
10
}
pub(crate) fn default_api_runtime_edge_events_capacity() -> usize {
256
}
pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
500 500
@ -518,6 +523,10 @@ pub(crate) fn default_mask_shape_hardening() -> bool {
true true
} }
pub(crate) fn default_mask_shape_hardening_aggressive_mode() -> bool {
false
}
pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize { pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize {
512 512
} }

View File

@ -31,11 +31,10 @@ use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
MeWriterPickMode,
};
use super::load::{LoadedConfig, ProxyConfig}; use super::load::{LoadedConfig, ProxyConfig};
use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode,
};
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
@ -189,15 +188,11 @@ impl HotFields {
me_adaptive_floor_min_writers_multi_endpoint: cfg me_adaptive_floor_min_writers_multi_endpoint: cfg
.general .general
.me_adaptive_floor_min_writers_multi_endpoint, .me_adaptive_floor_min_writers_multi_endpoint,
me_adaptive_floor_recover_grace_secs: cfg me_adaptive_floor_recover_grace_secs: cfg.general.me_adaptive_floor_recover_grace_secs,
.general
.me_adaptive_floor_recover_grace_secs,
me_adaptive_floor_writers_per_core_total: cfg me_adaptive_floor_writers_per_core_total: cfg
.general .general
.me_adaptive_floor_writers_per_core_total, .me_adaptive_floor_writers_per_core_total,
me_adaptive_floor_cpu_cores_override: cfg me_adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override,
.general
.me_adaptive_floor_cpu_cores_override,
me_adaptive_floor_max_extra_writers_single_per_core: cfg me_adaptive_floor_max_extra_writers_single_per_core: cfg
.general .general
.me_adaptive_floor_max_extra_writers_single_per_core, .me_adaptive_floor_max_extra_writers_single_per_core,
@ -216,9 +211,15 @@ impl HotFields {
me_adaptive_floor_max_warm_writers_global: cfg me_adaptive_floor_max_warm_writers_global: cfg
.general .general
.me_adaptive_floor_max_warm_writers_global, .me_adaptive_floor_max_warm_writers_global,
me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms, me_route_backpressure_base_timeout_ms: cfg
me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms, .general
me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct, .me_route_backpressure_base_timeout_ms,
me_route_backpressure_high_timeout_ms: cfg
.general
.me_route_backpressure_high_timeout_ms,
me_route_backpressure_high_watermark_pct: cfg
.general
.me_route_backpressure_high_watermark_pct,
me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms, me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms,
me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames, me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames,
me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes,
@ -334,7 +335,9 @@ struct ReloadState {
impl ReloadState { impl ReloadState {
fn new(applied_snapshot_hash: Option<u64>) -> Self { fn new(applied_snapshot_hash: Option<u64>) -> Self {
Self { applied_snapshot_hash } Self {
applied_snapshot_hash,
}
} }
fn is_applied(&self, hash: u64) -> bool { fn is_applied(&self, hash: u64) -> bool {
@ -481,10 +484,14 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
new.general.me_adaptive_floor_writers_per_core_total; new.general.me_adaptive_floor_writers_per_core_total;
cfg.general.me_adaptive_floor_cpu_cores_override = cfg.general.me_adaptive_floor_cpu_cores_override =
new.general.me_adaptive_floor_cpu_cores_override; new.general.me_adaptive_floor_cpu_cores_override;
cfg.general.me_adaptive_floor_max_extra_writers_single_per_core = cfg.general
new.general.me_adaptive_floor_max_extra_writers_single_per_core; .me_adaptive_floor_max_extra_writers_single_per_core = new
cfg.general.me_adaptive_floor_max_extra_writers_multi_per_core = .general
new.general.me_adaptive_floor_max_extra_writers_multi_per_core; .me_adaptive_floor_max_extra_writers_single_per_core;
cfg.general
.me_adaptive_floor_max_extra_writers_multi_per_core = new
.general
.me_adaptive_floor_max_extra_writers_multi_per_core;
cfg.general.me_adaptive_floor_max_active_writers_per_core = cfg.general.me_adaptive_floor_max_active_writers_per_core =
new.general.me_adaptive_floor_max_active_writers_per_core; new.general.me_adaptive_floor_max_active_writers_per_core;
cfg.general.me_adaptive_floor_max_warm_writers_per_core = cfg.general.me_adaptive_floor_max_warm_writers_per_core =
@ -543,8 +550,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.api.minimal_runtime_cache_ttl_ms || old.server.api.minimal_runtime_cache_ttl_ms
!= new.server.api.minimal_runtime_cache_ttl_ms != new.server.api.minimal_runtime_cache_ttl_ms
|| old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled || old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled
|| old.server.api.runtime_edge_cache_ttl_ms || old.server.api.runtime_edge_cache_ttl_ms != new.server.api.runtime_edge_cache_ttl_ms
!= new.server.api.runtime_edge_cache_ttl_ms
|| old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n || old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n
|| old.server.api.runtime_edge_events_capacity || old.server.api.runtime_edge_events_capacity
!= new.server.api.runtime_edge_events_capacity != new.server.api.runtime_edge_events_capacity
@ -583,10 +589,8 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening || old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening
|| old.censorship.mask_shape_bucket_floor_bytes || old.censorship.mask_shape_bucket_floor_bytes
!= new.censorship.mask_shape_bucket_floor_bytes != new.censorship.mask_shape_bucket_floor_bytes
|| old.censorship.mask_shape_bucket_cap_bytes || old.censorship.mask_shape_bucket_cap_bytes != new.censorship.mask_shape_bucket_cap_bytes
!= new.censorship.mask_shape_bucket_cap_bytes || old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur
!= new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes || old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_enabled
@ -870,8 +874,7 @@ fn log_changes(
{ {
info!( info!(
"config reload: me_bind_stale: mode={:?} ttl={}s", "config reload: me_bind_stale: mode={:?} ttl={}s",
new_hot.me_bind_stale_mode, new_hot.me_bind_stale_mode, new_hot.me_bind_stale_ttl_secs
new_hot.me_bind_stale_ttl_secs
); );
} }
if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot
@ -951,8 +954,7 @@ fn log_changes(
if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy { if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy {
info!( info!(
"config reload: me_socks_kdf_policy: {:?} → {:?}", "config reload: me_socks_kdf_policy: {:?} → {:?}",
old_hot.me_socks_kdf_policy, old_hot.me_socks_kdf_policy, new_hot.me_socks_kdf_policy,
new_hot.me_socks_kdf_policy,
); );
} }
@ -1006,8 +1008,7 @@ fn log_changes(
|| old_hot.me_route_backpressure_high_watermark_pct || old_hot.me_route_backpressure_high_watermark_pct
!= new_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct
|| old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms
|| old_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy
!= new_hot.me_health_interval_ms_unhealthy
|| old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy
|| old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms || old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms
|| old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms
@ -1044,19 +1045,27 @@ fn log_changes(
} }
if old_hot.users != new_hot.users { if old_hot.users != new_hot.users {
let mut added: Vec<&String> = new_hot.users.keys() let mut added: Vec<&String> = new_hot
.users
.keys()
.filter(|u| !old_hot.users.contains_key(*u)) .filter(|u| !old_hot.users.contains_key(*u))
.collect(); .collect();
added.sort(); added.sort();
let mut removed: Vec<&String> = old_hot.users.keys() let mut removed: Vec<&String> = old_hot
.users
.keys()
.filter(|u| !new_hot.users.contains_key(*u)) .filter(|u| !new_hot.users.contains_key(*u))
.collect(); .collect();
removed.sort(); removed.sort();
let mut changed: Vec<&String> = new_hot.users.keys() let mut changed: Vec<&String> = new_hot
.users
.keys()
.filter(|u| { .filter(|u| {
old_hot.users.get(*u) old_hot
.users
.get(*u)
.map(|s| s != &new_hot.users[*u]) .map(|s| s != &new_hot.users[*u])
.unwrap_or(false) .unwrap_or(false)
}) })
@ -1066,10 +1075,18 @@ fn log_changes(
if !added.is_empty() { if !added.is_empty() {
info!( info!(
"config reload: users added: [{}]", "config reload: users added: [{}]",
added.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ") added
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6); let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6);
let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port); let port = new_cfg
.general
.links
.public_port
.unwrap_or(new_cfg.server.port);
for user in &added { for user in &added {
if let Some(secret) = new_hot.users.get(*user) { if let Some(secret) = new_hot.users.get(*user) {
print_user_links(user, secret, &host, port, new_cfg); print_user_links(user, secret, &host, port, new_cfg);
@ -1079,13 +1096,21 @@ fn log_changes(
if !removed.is_empty() { if !removed.is_empty() {
info!( info!(
"config reload: users removed: [{}]", "config reload: users removed: [{}]",
removed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ") removed
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
} }
if !changed.is_empty() { if !changed.is_empty() {
info!( info!(
"config reload: users secret changed: [{}]", "config reload: users secret changed: [{}]",
changed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ") changed
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
} }
} }
@ -1116,8 +1141,7 @@ fn log_changes(
} }
if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each
|| old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode || old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode
|| old_hot.user_max_unique_ips_window_secs || old_hot.user_max_unique_ips_window_secs != new_hot.user_max_unique_ips_window_secs
!= new_hot.user_max_unique_ips_window_secs
{ {
info!( info!(
"config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s", "config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s",
@ -1152,7 +1176,10 @@ fn reload_config(
let next_manifest = WatchManifest::from_source_files(&source_files); let next_manifest = WatchManifest::from_source_files(&source_files);
if let Err(e) = new_cfg.validate() { if let Err(e) = new_cfg.validate() {
error!("config reload: validation failed: {}; keeping old config", e); error!(
"config reload: validation failed: {}; keeping old config",
e
);
return Some(next_manifest); return Some(next_manifest);
} }
@ -1234,9 +1261,13 @@ pub fn spawn_config_watcher(
let tx_inotify = notify_tx.clone(); let tx_inotify = notify_tx.clone();
let manifest_for_inotify = manifest_state.clone(); let manifest_for_inotify = manifest_state.clone();
let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| { let mut inotify_watcher =
match recommended_watcher(move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return }; let Ok(event) = res else { return };
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { if !matches!(
event.kind,
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
) {
return; return;
} }
let is_our_file = manifest_for_inotify let is_our_file = manifest_for_inotify
@ -1268,7 +1299,10 @@ pub fn spawn_config_watcher(
let mut poll_watcher = match notify::poll::PollWatcher::new( let mut poll_watcher = match notify::poll::PollWatcher::new(
move |res: notify::Result<notify::Event>| { move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return }; let Ok(event) = res else { return };
if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) { if !matches!(
event.kind,
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
) {
return; return;
} }
let is_our_file = manifest_for_poll let is_our_file = manifest_for_poll
@ -1316,7 +1350,9 @@ pub fn spawn_config_watcher(
} }
} }
#[cfg(not(unix))] #[cfg(not(unix))]
if notify_rx.recv().await.is_none() { break; } if notify_rx.recv().await.is_none() {
break;
}
// Debounce: drain extra events that arrive within a short quiet window. // Debounce: drain extra events that arrive within a short quiet window.
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
@ -1418,7 +1454,10 @@ mod tests {
new.server.port = old.server.port.saturating_add(1); new.server.port = old.server.port.saturating_add(1);
let applied = overlay_hot_fields(&old, &new); let applied = overlay_hot_fields(&old, &new);
assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); assert_eq!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
assert_eq!(applied.server.port, old.server.port); assert_eq!(applied.server.port, old.server.port);
} }
@ -1437,7 +1476,10 @@ mod tests {
applied.general.me_bind_stale_mode, applied.general.me_bind_stale_mode,
new.general.me_bind_stale_mode new.general.me_bind_stale_mode
); );
assert_ne!(HotFields::from_config(&old), HotFields::from_config(&applied)); assert_ne!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
} }
#[test] #[test]
@ -1451,7 +1493,10 @@ mod tests {
applied.general.me_keepalive_interval_secs, applied.general.me_keepalive_interval_secs,
old.general.me_keepalive_interval_secs old.general.me_keepalive_interval_secs
); );
assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied)); assert_eq!(
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
} }
#[test] #[test]
@ -1463,7 +1508,10 @@ mod tests {
let applied = overlay_hot_fields(&old, &new); let applied = overlay_hot_fields(&old, &new);
assert_eq!(applied.general.hardswap, new.general.hardswap); assert_eq!(applied.general.hardswap, new.general.hardswap);
assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy); assert_eq!(
applied.general.use_middle_proxy,
old.general.use_middle_proxy
);
assert!(!config_equal(&applied, &new)); assert!(!config_equal(&applied, &new));
} }
@ -1475,14 +1523,19 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, Some(final_tag), None); write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }
@ -1495,7 +1548,9 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
@ -1518,7 +1573,9 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; let initial_hash = ProxyConfig::load_with_metadata(&path)
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
@ -1532,7 +1589,10 @@ mod tests {
write_reload_config(&path, Some(final_tag), None); write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); assert_eq!(
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }

View File

@ -399,11 +399,18 @@ impl ProxyConfig {
)); ));
} }
if config.censorship.mask_shape_above_cap_blur if config.censorship.mask_shape_above_cap_blur && !config.censorship.mask_shape_hardening {
return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
.to_string(),
));
}
if config.censorship.mask_shape_hardening_aggressive_mode
&& !config.censorship.mask_shape_hardening && !config.censorship.mask_shape_hardening
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" "censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"
.to_string(), .to_string(),
)); ));
} }
@ -419,8 +426,7 @@ impl ProxyConfig {
if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 { if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576" "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576".to_string(),
.to_string(),
)); ));
} }
@ -444,8 +450,7 @@ impl ProxyConfig {
if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 { if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_timing_normalization_ceiling_ms must be <= 60000" "censorship.mask_timing_normalization_ceiling_ms must be <= 60000".to_string(),
.to_string(),
)); ));
} }
@ -461,8 +466,7 @@ impl ProxyConfig {
)); ));
} }
if config.timeouts.relay_client_idle_hard_secs if config.timeouts.relay_client_idle_hard_secs < config.timeouts.relay_client_idle_soft_secs
< config.timeouts.relay_client_idle_soft_secs
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
@ -470,7 +474,9 @@ impl ProxyConfig {
)); ));
} }
if config.timeouts.relay_idle_grace_after_downstream_activity_secs if config
.timeouts
.relay_idle_grace_after_downstream_activity_secs
> config.timeouts.relay_client_idle_hard_secs > config.timeouts.relay_client_idle_hard_secs
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
@ -767,7 +773,8 @@ impl ProxyConfig {
} }
if config.general.me_route_backpressure_base_timeout_ms > 5000 { if config.general.me_route_backpressure_base_timeout_ms > 5000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(), "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]"
.to_string(),
)); ));
} }
@ -780,7 +787,8 @@ impl ProxyConfig {
} }
if config.general.me_route_backpressure_high_timeout_ms > 5000 { if config.general.me_route_backpressure_high_timeout_ms > 5000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(), "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]"
.to_string(),
)); ));
} }
@ -1828,7 +1836,9 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml"); let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap(); std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string(); let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")); assert!(
err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }
@ -1849,7 +1859,9 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml"); let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap(); std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string(); let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")); assert!(
err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }

View File

@ -1,9 +1,9 @@
//! Configuration. //! Configuration.
pub(crate) mod defaults; pub(crate) mod defaults;
mod types;
mod load;
pub mod hot_reload; pub mod hot_reload;
mod load;
mod types;
pub use load::ProxyConfig; pub use load::ProxyConfig;
pub use types::*; pub use types::*;

View File

@ -30,7 +30,9 @@ relay_client_idle_hard_secs = 60
let err = ProxyConfig::load(&path).expect_err("config with hard<soft must fail"); let err = ProxyConfig::load(&path).expect_err("config with hard<soft must fail");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"), msg.contains(
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
),
"error must explain the violated hard>=soft invariant, got: {msg}" "error must explain the violated hard>=soft invariant, got: {msg}"
); );

View File

@ -91,11 +91,13 @@ mask_shape_above_cap_blur_max_bytes = 64
"#, "#,
); );
let err = ProxyConfig::load(&path) let err =
.expect_err("above-cap blur must require shape hardening enabled"); ProxyConfig::load(&path).expect_err("above-cap blur must require shape hardening enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"), msg.contains(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
),
"error must explain blur prerequisite, got: {msg}" "error must explain blur prerequisite, got: {msg}"
); );
@ -113,8 +115,8 @@ mask_shape_above_cap_blur_max_bytes = 0
"#, "#,
); );
let err = ProxyConfig::load(&path) let err =
.expect_err("above-cap blur max bytes must be > 0 when enabled"); ProxyConfig::load(&path).expect_err("above-cap blur max bytes must be > 0 when enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"), msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"),
@ -135,8 +137,8 @@ mask_timing_normalization_ceiling_ms = 200
"#, "#,
); );
let err = ProxyConfig::load(&path) let err =
.expect_err("timing normalization floor must be > 0 when enabled"); ProxyConfig::load(&path).expect_err("timing normalization floor must be > 0 when enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"), msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"),
@ -157,8 +159,7 @@ mask_timing_normalization_ceiling_ms = 200
"#, "#,
); );
let err = ProxyConfig::load(&path) let err = ProxyConfig::load(&path).expect_err("timing normalization ceiling must be >= floor");
.expect_err("timing normalization ceiling must be >= floor");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"), msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"),
@ -193,3 +194,45 @@ mask_timing_normalization_ceiling_ms = 240
remove_temp_config(&path); remove_temp_config(&path);
} }
#[test]
fn load_rejects_aggressive_shape_mode_when_shape_hardening_disabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = false
mask_shape_hardening_aggressive_mode = true
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("aggressive shape hardening mode must require shape hardening enabled");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"),
"error must explain aggressive-mode prerequisite, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_aggressive_shape_mode_when_shape_hardening_enabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_hardening_aggressive_mode = true
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 8
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("aggressive shape hardening mode should be accepted when prerequisites are met");
assert!(cfg.censorship.mask_shape_hardening);
assert!(cfg.censorship.mask_shape_hardening_aggressive_mode);
assert!(cfg.censorship.mask_shape_above_cap_blur);
remove_temp_config(&path);
}

View File

@ -29,11 +29,13 @@ server_hello_delay_max_ms = 1000
"#, "#,
); );
let err = ProxyConfig::load(&path) let err =
.expect_err("delay equal to handshake timeout must be rejected"); ProxyConfig::load(&path).expect_err("delay equal to handshake timeout must be rejected");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"), msg.contains(
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}" "error must explain delay<timeout invariant, got: {msg}"
); );
@ -52,11 +54,13 @@ server_hello_delay_max_ms = 1500
"#, "#,
); );
let err = ProxyConfig::load(&path) let err =
.expect_err("delay larger than handshake timeout must be rejected"); ProxyConfig::load(&path).expect_err("delay larger than handshake timeout must be rejected");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"), msg.contains(
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}" "error must explain delay<timeout invariant, got: {msg}"
); );
@ -75,8 +79,8 @@ server_hello_delay_max_ms = 999
"#, "#,
); );
let cfg = ProxyConfig::load(&path) let cfg =
.expect("delay below handshake timeout budget must be accepted"); ProxyConfig::load(&path).expect("delay below handshake timeout budget must be accepted");
assert_eq!(cfg.timeouts.client_handshake, 1); assert_eq!(cfg.timeouts.client_handshake, 1);
assert_eq!(cfg.censorship.server_hello_delay_max_ms, 999); assert_eq!(cfg.censorship.server_hello_delay_max_ms, 999);

View File

@ -1047,8 +1047,7 @@ impl Default for GeneralConfig {
me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(), me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(),
me_pool_drain_soft_evict_budget_per_core: me_pool_drain_soft_evict_budget_per_core:
default_me_pool_drain_soft_evict_budget_per_core(), default_me_pool_drain_soft_evict_budget_per_core(),
me_pool_drain_soft_evict_cooldown_ms: me_pool_drain_soft_evict_cooldown_ms: default_me_pool_drain_soft_evict_cooldown_ms(),
default_me_pool_drain_soft_evict_cooldown_ms(),
me_bind_stale_mode: MeBindStaleMode::default(), me_bind_stale_mode: MeBindStaleMode::default(),
me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(), me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(),
me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(), me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(),
@ -1418,6 +1417,12 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_hardening")] #[serde(default = "default_mask_shape_hardening")]
pub mask_shape_hardening: bool, pub mask_shape_hardening: bool,
/// Opt-in aggressive shape hardening mode.
/// When enabled, masking may shape some backend-silent timeout paths and
/// enforces strictly positive above-cap blur when blur is enabled.
#[serde(default = "default_mask_shape_hardening_aggressive_mode")]
pub mask_shape_hardening_aggressive_mode: bool,
/// Minimum bucket size for mask shape hardening padding. /// Minimum bucket size for mask shape hardening padding.
#[serde(default = "default_mask_shape_bucket_floor_bytes")] #[serde(default = "default_mask_shape_bucket_floor_bytes")]
pub mask_shape_bucket_floor_bytes: usize, pub mask_shape_bucket_floor_bytes: usize,
@ -1468,6 +1473,7 @@ impl Default for AntiCensorshipConfig {
alpn_enforce: default_alpn_enforce(), alpn_enforce: default_alpn_enforce(),
mask_proxy_protocol: 0, mask_proxy_protocol: 0,
mask_shape_hardening: default_mask_shape_hardening(), mask_shape_hardening: default_mask_shape_hardening(),
mask_shape_hardening_aggressive_mode: default_mask_shape_hardening_aggressive_mode(),
mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(), mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(),
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),

View File

@ -13,10 +13,13 @@
#![allow(dead_code)] #![allow(dead_code)]
use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use zeroize::Zeroize;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use aes::Aes256;
use ctr::{
Ctr128BE,
cipher::{KeyIvInit, StreamCipher},
};
use zeroize::Zeroize;
type Aes256Ctr = Ctr128BE<Aes256>; type Aes256Ctr = Ctr128BE<Aes256>;
@ -46,10 +49,16 @@ impl AesCtr {
/// Create from key and IV slices /// Create from key and IV slices
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); return Err(ProxyError::InvalidKeyLength {
expected: 32,
got: key.len(),
});
} }
if iv.len() != 16 { if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); return Err(ProxyError::InvalidKeyLength {
expected: 16,
got: iv.len(),
});
} }
let key: [u8; 32] = key.try_into().unwrap(); let key: [u8; 32] = key.try_into().unwrap();
@ -108,10 +117,16 @@ impl AesCbc {
/// Create from slices /// Create from slices
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() }); return Err(ProxyError::InvalidKeyLength {
expected: 32,
got: key.len(),
});
} }
if iv.len() != 16 { if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() }); return Err(ProxyError::InvalidKeyLength {
expected: 16,
got: iv.len(),
});
} }
Ok(Self { Ok(Self {
@ -150,9 +165,10 @@ impl AesCbc {
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV /// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> { pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(format!(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) "CBC data must be aligned to 16 bytes, got {}",
)); data.len()
)));
} }
if data.is_empty() { if data.is_empty() {
@ -181,9 +197,10 @@ impl AesCbc {
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV /// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> { pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(format!(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) "CBC data must be aligned to 16 bytes, got {}",
)); data.len()
)));
} }
if data.is_empty() { if data.is_empty() {
@ -210,9 +227,10 @@ impl AesCbc {
/// Encrypt data in-place /// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(format!(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) "CBC data must be aligned to 16 bytes, got {}",
)); data.len()
)));
} }
if data.is_empty() { if data.is_empty() {
@ -243,9 +261,10 @@ impl AesCbc {
/// Decrypt data in-place /// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto( return Err(ProxyError::Crypto(format!(
format!("CBC data must be aligned to 16 bytes, got {}", data.len()) "CBC data must be aligned to 16 bytes, got {}",
)); data.len()
)));
} }
if data.is_empty() { if data.is_empty() {

View File

@ -12,10 +12,10 @@
//! usages are intentional and protocol-mandated. //! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256;
use md5::Md5; use md5::Md5;
use sha1::Sha1; use sha1::Sha1;
use sha2::Digest; use sha2::Digest;
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>; type HmacSha256 = Hmac<Sha256>;
@ -28,8 +28,7 @@ pub fn sha256(data: &[u8]) -> [u8; 32] {
/// SHA-256 HMAC /// SHA-256 HMAC
pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(key) let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
.expect("HMAC accepts any key length");
mac.update(data); mac.update(data);
mac.finalize().into_bytes().into() mac.finalize().into_bytes().into()
} }
@ -124,17 +123,8 @@ pub fn derive_middleproxy_keys(
srv_ipv6: Option<&[u8; 16]>, srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) { ) -> ([u8; 32], [u8; 16]) {
let s = build_middleproxy_prekey( let s = build_middleproxy_prekey(
nonce_srv, nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, secret,
nonce_clt, clt_ipv6, srv_ipv6,
clt_ts,
srv_ip,
clt_port,
purpose,
clt_ip,
srv_port,
secret,
clt_ipv6,
srv_ipv6,
); );
let md5_1 = md5(&s[1..]); let md5_1 = md5(&s[1..]);
@ -164,17 +154,8 @@ mod tests {
let secret = vec![0x55u8; 128]; let secret = vec![0x55u8; 128];
let prekey = build_middleproxy_prekey( let prekey = build_middleproxy_prekey(
&nonce_srv, &nonce_srv, &nonce_clt, &clt_ts, srv_ip, &clt_port, b"CLIENT", clt_ip, &srv_port,
&nonce_clt, &secret, None, None,
&clt_ts,
srv_ip,
&clt_port,
b"CLIENT",
clt_ip,
&srv_port,
&secret,
None,
None,
); );
let digest = sha256(&prekey); let digest = sha256(&prekey);
assert_eq!( assert_eq!(

View File

@ -4,7 +4,7 @@ pub mod aes;
pub mod hash; pub mod hash;
pub mod random; pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCbc, AesCtr};
pub use hash::{ pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac, build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac,
}; };

View File

@ -3,11 +3,11 @@
#![allow(deprecated)] #![allow(deprecated)]
#![allow(dead_code)] #![allow(dead_code)]
use rand::{Rng, RngExt, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use parking_lot::Mutex;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use zeroize::Zeroize;
/// Cryptographically secure PRNG with AES-CTR /// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom { pub struct SecureRandom {

View File

@ -12,28 +12,15 @@ use thiserror::Error;
#[derive(Debug)] #[derive(Debug)]
pub enum StreamError { pub enum StreamError {
/// Partial read: got fewer bytes than expected /// Partial read: got fewer bytes than expected
PartialRead { PartialRead { expected: usize, got: usize },
expected: usize,
got: usize,
},
/// Partial write: wrote fewer bytes than expected /// Partial write: wrote fewer bytes than expected
PartialWrite { PartialWrite { expected: usize, written: usize },
expected: usize,
written: usize,
},
/// Stream is in poisoned state and cannot be used /// Stream is in poisoned state and cannot be used
Poisoned { Poisoned { reason: String },
reason: String,
},
/// Buffer overflow: attempted to buffer more than allowed /// Buffer overflow: attempted to buffer more than allowed
BufferOverflow { BufferOverflow { limit: usize, attempted: usize },
limit: usize,
attempted: usize,
},
/// Invalid frame format /// Invalid frame format
InvalidFrame { InvalidFrame { details: String },
details: String,
},
/// Unexpected end of stream /// Unexpected end of stream
UnexpectedEof, UnexpectedEof,
/// Underlying I/O error /// Underlying I/O error
@ -47,13 +34,21 @@ impl fmt::Display for StreamError {
write!(f, "partial read: expected {} bytes, got {}", expected, got) write!(f, "partial read: expected {} bytes, got {}", expected, got)
} }
Self::PartialWrite { expected, written } => { Self::PartialWrite { expected, written } => {
write!(f, "partial write: expected {} bytes, wrote {}", expected, written) write!(
f,
"partial write: expected {} bytes, wrote {}",
expected, written
)
} }
Self::Poisoned { reason } => { Self::Poisoned { reason } => {
write!(f, "stream poisoned: {}", reason) write!(f, "stream poisoned: {}", reason)
} }
Self::BufferOverflow { limit, attempted } => { Self::BufferOverflow { limit, attempted } => {
write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted) write!(
f,
"buffer overflow: limit {}, attempted {}",
limit, attempted
)
} }
Self::InvalidFrame { details } => { Self::InvalidFrame { details } => {
write!(f, "invalid frame: {}", details) write!(f, "invalid frame: {}", details)
@ -90,9 +85,7 @@ impl From<StreamError> for std::io::Error {
StreamError::UnexpectedEof => { StreamError::UnexpectedEof => {
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err) std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
} }
StreamError::Poisoned { .. } => { StreamError::Poisoned { .. } => std::io::Error::other(err),
std::io::Error::other(err)
}
StreamError::BufferOverflow { .. } => { StreamError::BufferOverflow { .. } => {
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err) std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
} }
@ -135,7 +128,10 @@ impl Recoverable for StreamError {
} }
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
!matches!(self, Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }) !matches!(
self,
Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }
)
} }
} }
@ -165,7 +161,6 @@ impl Recoverable for std::io::Error {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProxyError { pub enum ProxyError {
// ============= Crypto Errors ============= // ============= Crypto Errors =============
#[error("Crypto error: {0}")] #[error("Crypto error: {0}")]
Crypto(String), Crypto(String),
@ -173,12 +168,10 @@ pub enum ProxyError {
InvalidKeyLength { expected: usize, got: usize }, InvalidKeyLength { expected: usize, got: usize },
// ============= Stream Errors ============= // ============= Stream Errors =============
#[error("Stream error: {0}")] #[error("Stream error: {0}")]
Stream(#[from] StreamError), Stream(#[from] StreamError),
// ============= Protocol Errors ============= // ============= Protocol Errors =============
#[error("Invalid handshake: {0}")] #[error("Invalid handshake: {0}")]
InvalidHandshake(String), InvalidHandshake(String),
@ -210,7 +203,6 @@ pub enum ProxyError {
TgHandshakeTimeout, TgHandshakeTimeout,
// ============= Network Errors ============= // ============= Network Errors =============
#[error("Connection timeout to {addr}")] #[error("Connection timeout to {addr}")]
ConnectionTimeout { addr: String }, ConnectionTimeout { addr: String },
@ -221,7 +213,6 @@ pub enum ProxyError {
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
// ============= Proxy Protocol Errors ============= // ============= Proxy Protocol Errors =============
#[error("Invalid proxy protocol header")] #[error("Invalid proxy protocol header")]
InvalidProxyProtocol, InvalidProxyProtocol,
@ -229,7 +220,6 @@ pub enum ProxyError {
Proxy(String), Proxy(String),
// ============= Config Errors ============= // ============= Config Errors =============
#[error("Config error: {0}")] #[error("Config error: {0}")]
Config(String), Config(String),
@ -237,7 +227,6 @@ pub enum ProxyError {
InvalidSecret { user: String, reason: String }, InvalidSecret { user: String, reason: String },
// ============= User Errors ============= // ============= User Errors =============
#[error("User {user} expired")] #[error("User {user} expired")]
UserExpired { user: String }, UserExpired { user: String },
@ -254,7 +243,6 @@ pub enum ProxyError {
RateLimited, RateLimited,
// ============= General Errors ============= // ============= General Errors =============
#[error("Internal error: {0}")] #[error("Internal error: {0}")]
Internal(String), Internal(String),
} }
@ -311,7 +299,9 @@ impl<T, R, W> HandshakeResult<T, R, W> {
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer }, HandshakeResult::BadClient { reader, writer } => {
HandshakeResult::BadClient { reader, writer }
}
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
@ -341,18 +331,35 @@ mod tests {
#[test] #[test]
fn test_stream_error_display() { fn test_stream_error_display() {
let err = StreamError::PartialRead { expected: 100, got: 50 }; let err = StreamError::PartialRead {
expected: 100,
got: 50,
};
assert!(err.to_string().contains("100")); assert!(err.to_string().contains("100"));
assert!(err.to_string().contains("50")); assert!(err.to_string().contains("50"));
let err = StreamError::Poisoned { reason: "test".into() }; let err = StreamError::Poisoned {
reason: "test".into(),
};
assert!(err.to_string().contains("test")); assert!(err.to_string().contains("test"));
} }
#[test] #[test]
fn test_stream_error_recoverable() { fn test_stream_error_recoverable() {
assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable()); assert!(
assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable()); StreamError::PartialRead {
expected: 10,
got: 5
}
.is_recoverable()
);
assert!(
StreamError::PartialWrite {
expected: 10,
written: 5
}
.is_recoverable()
);
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable()); assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
assert!(!StreamError::UnexpectedEof.is_recoverable()); assert!(!StreamError::UnexpectedEof.is_recoverable());
} }
@ -361,7 +368,13 @@ mod tests {
fn test_stream_error_can_continue() { fn test_stream_error_can_continue() {
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue()); assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
assert!(!StreamError::UnexpectedEof.can_continue()); assert!(!StreamError::UnexpectedEof.can_continue());
assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue()); assert!(
StreamError::PartialRead {
expected: 10,
got: 5
}
.can_continue()
);
} }
#[test] #[test]
@ -377,7 +390,10 @@ mod tests {
assert!(success.is_success()); assert!(success.is_success());
assert!(!success.is_bad_client()); assert!(!success.is_bad_client());
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () }; let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient {
reader: (),
writer: (),
};
assert!(!bad.is_success()); assert!(!bad.is_success());
assert!(bad.is_bad_client()); assert!(bad.is_bad_client());
} }
@ -404,7 +420,9 @@ mod tests {
#[test] #[test]
fn test_error_display() { fn test_error_display() {
let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() }; let err = ProxyError::ConnectionTimeout {
addr: "1.2.3.4:443".into(),
};
assert!(err.to_string().contains("1.2.3.4:443")); assert!(err.to_string().contains("1.2.3.4:443"));
let err = ProxyError::InvalidProxyProtocol; let err = ProxyError::InvalidProxyProtocol;

View File

@ -5,9 +5,9 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::sync::Mutex;
use tokio::sync::{Mutex as AsyncMutex, RwLock}; use tokio::sync::{Mutex as AsyncMutex, RwLock};
@ -22,7 +22,7 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>, limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>, limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>, last_compact_epoch_secs: Arc<AtomicU64>,
pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>, cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>, cleanup_drain_lock: Arc<AsyncMutex<()>>,
} }
@ -41,7 +41,6 @@ impl UserIpTracker {
} }
} }
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() { match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)), Ok(mut queue) => queue.push((user, ip)),
@ -58,6 +57,19 @@ impl UserIpTracker {
} }
} }
#[cfg(test)]
pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize {
self.cleanup_queue
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.len()
}
#[cfg(test)]
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
Arc::clone(&self.cleanup_queue)
}
pub(crate) async fn drain_cleanup_queue(&self) { pub(crate) async fn drain_cleanup_queue(&self) {
// Serialize queue draining and active-IP mutation so check-and-add cannot // Serialize queue draining and active-IP mutation so check-and-add cannot
// observe stale active entries that are already queued for removal. // observe stale active entries that are already queued for removal.
@ -129,7 +141,8 @@ impl UserIpTracker {
let mut active_ips = self.active_ips.write().await; let mut active_ips = self.active_ips.write().await;
let mut recent_ips = self.recent_ips.write().await; let mut recent_ips = self.recent_ips.write().await;
let mut users = Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len())); let mut users =
Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
users.extend(active_ips.keys().cloned()); users.extend(active_ips.keys().cloned());
for user in recent_ips.keys() { for user in recent_ips.keys() {
if !active_ips.contains_key(user) { if !active_ips.contains_key(user) {
@ -138,8 +151,14 @@ impl UserIpTracker {
} }
for user in users { for user in users {
let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); let active_empty = active_ips
let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); .get(&user)
.map(|ips| ips.is_empty())
.unwrap_or(true);
let recent_empty = recent_ips
.get(&user)
.map(|ips| ips.is_empty())
.unwrap_or(true);
if active_empty && recent_empty { if active_empty && recent_empty {
active_ips.remove(&user); active_ips.remove(&user);
recent_ips.remove(&user); recent_ips.remove(&user);

View File

@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -11,10 +13,10 @@ use crate::startup::{
COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY, COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY,
StartupTracker, StartupTracker,
}; };
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{ use crate::transport::middle_proxy::{
MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping, MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping,
}; };
use crate::transport::UpstreamManager;
pub(crate) async fn run_startup_connectivity( pub(crate) async fn run_startup_connectivity(
config: &Arc<ProxyConfig>, config: &Arc<ProxyConfig>,
@ -47,11 +49,15 @@ pub(crate) async fn run_startup_connectivity(
let v4_ok = me_results.iter().any(|r| { let v4_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V4) matches!(r.family, MePingFamily::V4)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) && r.samples
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
}); });
let v6_ok = me_results.iter().any(|r| { let v6_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V6) matches!(r.family, MePingFamily::V6)
&& r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) && r.samples
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
}); });
info!("================= Telegram ME Connectivity ================="); info!("================= Telegram ME Connectivity =================");
@ -131,8 +137,14 @@ pub(crate) async fn run_startup_connectivity(
.await; .await;
for upstream_result in &ping_results { for upstream_result in &ping_results {
let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some()); let v6_works = upstream_result
let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some()); .v6_results
.iter()
.any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result
.v4_results
.iter()
.any(|r| r.rtt_ms.is_some());
if upstream_result.both_available { if upstream_result.both_available {
if prefer_ipv6 { if prefer_ipv6 {

View File

@ -1,5 +1,7 @@
use std::time::Duration; #![allow(clippy::items_after_test_module)]
use std::path::PathBuf; use std::path::PathBuf;
use std::time::Duration;
use tokio::sync::watch; use tokio::sync::watch;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
@ -10,7 +12,10 @@ use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
}; };
pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf { pub(crate) fn resolve_runtime_config_path(
config_path_cli: &str,
startup_cwd: &std::path::Path,
) -> PathBuf {
let raw = PathBuf::from(config_path_cli); let raw = PathBuf::from(config_path_cli);
let absolute = if raw.is_absolute() { let absolute = if raw.is_absolute() {
raw raw
@ -50,7 +55,9 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
} }
} }
s if s.starts_with("--data-path=") => { s if s.starts_with("--data-path=") => {
data_path = Some(PathBuf::from(s.trim_start_matches("--data-path=").to_string())); data_path = Some(PathBuf::from(
s.trim_start_matches("--data-path=").to_string(),
));
} }
"--silent" | "-s" => { "--silent" | "-s" => {
silent = true; silent = true;
@ -68,7 +75,9 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
eprintln!("Usage: telemt [config.toml] [OPTIONS]"); eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!(); eprintln!();
eprintln!("Options:"); eprintln!("Options:");
eprintln!(" --data-path <DIR> Set data directory (absolute path; overrides config value)"); eprintln!(
" --data-path <DIR> Set data directory (absolute path; overrides config value)"
);
eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent"); eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help"); eprintln!(" --help, -h Show this help");
@ -146,7 +155,12 @@ mod tests {
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
for user_name in config.general.links.show.resolve_users(&config.access.users) { for user_name in config
.general
.links
.show
.resolve_users(&config.access.users)
{
if let Some(secret) = config.access.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!(target: "telemt::links", "User: {}", user_name); info!(target: "telemt::links", "User: {}", user_name);
if config.general.modes.classic { if config.general.modes.classic {
@ -287,7 +301,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
return Some(cfg); return Some(cfg);
} }
warn!(snapshot = label, url, "Startup proxy-config is empty; trying disk cache"); warn!(
snapshot = label,
url, "Startup proxy-config is empty; trying disk cache"
);
if let Some(path) = cache_path { if let Some(path) = cache_path {
match load_proxy_config_cache(path).await { match load_proxy_config_cache(path).await {
Ok(cached) if !cached.map.is_empty() => { Ok(cached) if !cached.map.is_empty() => {
@ -302,8 +319,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
Ok(_) => { Ok(_) => {
warn!( warn!(
snapshot = label, snapshot = label,
path, path, "Startup proxy-config cache is empty; ignoring cache file"
"Startup proxy-config cache is empty; ignoring cache file"
); );
} }
Err(cache_err) => { Err(cache_err) => {
@ -347,8 +363,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
Ok(_) => { Ok(_) => {
warn!( warn!(
snapshot = label, snapshot = label,
path, path, "Startup proxy-config cache is empty; ignoring cache file"
"Startup proxy-config cache is empty; ignoring cache file"
); );
} }
Err(cache_err) => { Err(cache_err) => {

View File

@ -12,17 +12,15 @@ use tracing::{debug, error, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::tls_front::TlsFrontCache; use crate::tls_front::TlsFrontCache;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::{ use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes};
ListenOptions, UpstreamManager, create_listener, find_listener_processes,
};
use super::helpers::{is_expected_handshake_eof, print_proxy_links}; use super::helpers::{is_expected_handshake_eof, print_proxy_links};
@ -81,8 +79,9 @@ pub(crate) async fn bind_listeners(
Ok(socket) => { Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?; let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr); info!("Listening on {}", addr);
let listener_proxy_protocol = let listener_proxy_protocol = listener_conf
listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol); .proxy_protocol
.unwrap_or(config.server.proxy_protocol);
let public_host = if let Some(ref announce) = listener_conf.announce { let public_host = if let Some(ref announce) = listener_conf.announce {
announce.clone() announce.clone()
@ -100,8 +99,14 @@ pub(crate) async fn bind_listeners(
listener_conf.ip.to_string() listener_conf.ip.to_string()
}; };
if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() { if config.general.links.public_host.is_none()
let link_port = config.general.links.public_port.unwrap_or(config.server.port); && !config.general.links.show.is_empty()
{
let link_port = config
.general
.links
.public_port
.unwrap_or(config.server.port);
print_proxy_links(&public_host, link_port, config); print_proxy_links(&public_host, link_port, config);
} }
@ -145,12 +150,14 @@ pub(crate) async fn bind_listeners(
let (host, port) = if let Some(ref h) = config.general.links.public_host { let (host, port) = if let Some(ref h) = config.general.links.public_host {
( (
h.clone(), h.clone(),
config.general.links.public_port.unwrap_or(config.server.port), config
.general
.links
.public_port
.unwrap_or(config.server.port),
) )
} else { } else {
let ip = detected_ip_v4 let ip = detected_ip_v4.or(detected_ip_v6).map(|ip| ip.to_string());
.or(detected_ip_v6)
.map(|ip| ip.to_string());
if ip.is_none() { if ip.is_none() {
warn!( warn!(
"show_link is configured but public IP could not be detected. Set public_host in config." "show_link is configured but public IP could not be detected. Set public_host in config."
@ -158,7 +165,11 @@ pub(crate) async fn bind_listeners(
} }
( (
ip.unwrap_or_else(|| "UNKNOWN".to_string()), ip.unwrap_or_else(|| "UNKNOWN".to_string()),
config.general.links.public_port.unwrap_or(config.server.port), config
.general
.links
.public_port
.unwrap_or(config.server.port),
) )
}; };
@ -178,13 +189,19 @@ pub(crate) async fn bind_listeners(
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(mode); let perms = std::fs::Permissions::from_mode(mode);
if let Err(e) = std::fs::set_permissions(unix_path, perms) { if let Err(e) = std::fs::set_permissions(unix_path, perms) {
error!("Failed to set unix socket permissions to {}: {}", perm_str, e); error!(
"Failed to set unix socket permissions to {}: {}",
perm_str, e
);
} else { } else {
info!("Listening on unix:{} (mode {})", unix_path, perm_str); info!("Listening on unix:{} (mode {})", unix_path, perm_str);
} }
} }
Err(e) => { Err(e) => {
warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e); warn!(
"Invalid listen_unix_sock_perm '{}': {}. Ignoring.",
perm_str, e
);
info!("Listening on unix:{}", unix_path); info!("Listening on unix:{}", unix_path);
} }
} }
@ -218,10 +235,8 @@ pub(crate) async fn bind_listeners(
drop(stream); drop(stream);
continue; continue;
} }
let accept_permit_timeout_ms = config_rx_unix let accept_permit_timeout_ms =
.borrow() config_rx_unix.borrow().server.accept_permit_timeout_ms;
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 { let permit = if accept_permit_timeout_ms == 0 {
match max_connections_unix.clone().acquire_owned().await { match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit, Ok(permit) => permit,
@ -361,10 +376,8 @@ pub(crate) fn spawn_tcp_accept_loops(
drop(stream); drop(stream);
continue; continue;
} }
let accept_permit_timeout_ms = config_rx let accept_permit_timeout_ms =
.borrow() config_rx.borrow().server.accept_permit_timeout_ms;
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 { let permit = if accept_permit_timeout_ms == 0 {
match max_connections_tcp.clone().acquire_owned().await { match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit, Ok(permit) => permit,

View File

@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -12,8 +14,8 @@ use crate::startup::{
COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker,
}; };
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
use super::helpers::load_startup_proxy_config_snapshot; use super::helpers::load_startup_proxy_config_snapshot;
@ -229,8 +231,12 @@ pub(crate) async fn initialize_me_pool(
config.general.me_adaptive_floor_recover_grace_secs, config.general.me_adaptive_floor_recover_grace_secs,
config.general.me_adaptive_floor_writers_per_core_total, config.general.me_adaptive_floor_writers_per_core_total,
config.general.me_adaptive_floor_cpu_cores_override, config.general.me_adaptive_floor_cpu_cores_override,
config.general.me_adaptive_floor_max_extra_writers_single_per_core, config
config.general.me_adaptive_floor_max_extra_writers_multi_per_core, .general
.me_adaptive_floor_max_extra_writers_single_per_core,
config
.general
.me_adaptive_floor_max_extra_writers_multi_per_core,
config.general.me_adaptive_floor_max_active_writers_per_core, config.general.me_adaptive_floor_max_active_writers_per_core,
config.general.me_adaptive_floor_max_warm_writers_per_core, config.general.me_adaptive_floor_max_warm_writers_per_core,
config.general.me_adaptive_floor_max_active_writers_global, config.general.me_adaptive_floor_max_active_writers_global,
@ -473,7 +479,9 @@ pub(crate) async fn initialize_me_pool(
}) })
.await; .await;
match res { match res {
Ok(()) => warn!("me_health_monitor exited unexpectedly, restarting"), Ok(()) => warn!(
"me_health_monitor exited unexpectedly, restarting"
),
Err(e) => { Err(e) => {
error!(error = %e, "me_health_monitor panicked, restarting in 1s"); error!(error = %e, "me_health_monitor panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
@ -490,7 +498,9 @@ pub(crate) async fn initialize_me_pool(
}) })
.await; .await;
match res { match res {
Ok(()) => warn!("me_drain_timeout_enforcer exited unexpectedly, restarting"), Ok(()) => warn!(
"me_drain_timeout_enforcer exited unexpectedly, restarting"
),
Err(e) => { Err(e) => {
error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
@ -507,7 +517,9 @@ pub(crate) async fn initialize_me_pool(
}) })
.await; .await;
match res { match res {
Ok(()) => warn!("me_zombie_writer_watchdog exited unexpectedly, restarting"), Ok(()) => warn!(
"me_zombie_writer_watchdog exited unexpectedly, restarting"
),
Err(e) => { Err(e) => {
error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;

View File

@ -11,9 +11,9 @@
// - admission: conditional-cast gate and route mode switching. // - admission: conditional-cast gate and route mode switching.
// - listeners: TCP/Unix listener bind and accept-loop orchestration. // - listeners: TCP/Unix listener bind and accept-loop orchestration.
// - shutdown: graceful shutdown sequence and uptime logging. // - shutdown: graceful shutdown sequence and uptime logging.
mod helpers;
mod admission; mod admission;
mod connectivity; mod connectivity;
mod helpers;
mod listeners; mod listeners;
mod me_startup; mod me_startup;
mod runtime_tasks; mod runtime_tasks;
@ -33,18 +33,18 @@ use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT,
COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6,
COMPONENT_ME_SECRET_FETCH, COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD,
COMPONENT_ME_POOL_CONSTRUCT, COMPONENT_ME_POOL_INIT_STAGE1,
COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH,
COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, StartupTracker,
};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool;
use helpers::{parse_cli, resolve_runtime_config_path}; use helpers::{parse_cli, resolve_runtime_config_path};
/// Runs the full telemt runtime startup pipeline and blocks until shutdown. /// Runs the full telemt runtime startup pipeline and blocks until shutdown.
@ -56,7 +56,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.as_secs(); .as_secs();
let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs)); let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs));
startup_tracker startup_tracker
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) .start_component(
COMPONENT_CONFIG_LOAD,
Some("load and validate config".to_string()),
)
.await; .await;
let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli();
let startup_cwd = match std::env::current_dir() { let startup_cwd = match std::env::current_dir() {
@ -77,7 +80,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
} else { } else {
let default = ProxyConfig::default(); let default = ProxyConfig::default();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path.display()); eprintln!(
"[telemt] Created default config at {}",
config_path.display()
);
default default
} }
} }
@ -94,24 +100,38 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
if let Some(ref data_path) = config.general.data_path { if let Some(ref data_path) = config.general.data_path {
if !data_path.is_absolute() { if !data_path.is_absolute() {
eprintln!("[telemt] data_path must be absolute: {}", data_path.display()); eprintln!(
"[telemt] data_path must be absolute: {}",
data_path.display()
);
std::process::exit(1); std::process::exit(1);
} }
if data_path.exists() { if data_path.exists() {
if !data_path.is_dir() { if !data_path.is_dir() {
eprintln!("[telemt] data_path exists but is not a directory: {}", data_path.display()); eprintln!(
"[telemt] data_path exists but is not a directory: {}",
data_path.display()
);
std::process::exit(1); std::process::exit(1);
} }
} else { } else {
if let Err(e) = std::fs::create_dir_all(data_path) { if let Err(e) = std::fs::create_dir_all(data_path) {
eprintln!("[telemt] Can't create data_path {}: {}", data_path.display(), e); eprintln!(
"[telemt] Can't create data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1); std::process::exit(1);
} }
} }
if let Err(e) = std::env::set_current_dir(data_path) { if let Err(e) = std::env::set_current_dir(data_path) {
eprintln!("[telemt] Can't use data_path {}: {}", data_path.display(), e); eprintln!(
"[telemt] Can't use data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1); std::process::exit(1);
} }
} }
@ -135,7 +155,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
startup_tracker startup_tracker
.start_component(COMPONENT_TRACING_INIT, Some("initialize tracing subscriber".to_string())) .start_component(
COMPONENT_TRACING_INIT,
Some("initialize tracing subscriber".to_string()),
)
.await; .await;
// Configure color output based on config // Configure color output based on config
@ -150,7 +173,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.with(fmt_layer) .with(fmt_layer)
.init(); .init();
startup_tracker startup_tracker
.complete_component(COMPONENT_TRACING_INIT, Some("tracing initialized".to_string())) .complete_component(
COMPONENT_TRACING_INIT,
Some("tracing initialized".to_string()),
)
.await; .await;
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
@ -216,7 +242,8 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
config.access.user_max_unique_ips_window_secs, config.access.user_max_unique_ips_window_secs,
) )
.await; .await;
if config.access.user_max_unique_ips_global_each > 0 || !config.access.user_max_unique_ips.is_empty() if config.access.user_max_unique_ips_global_each > 0
|| !config.access.user_max_unique_ips.is_empty()
{ {
info!( info!(
global_each_limit = config.access.user_max_unique_ips_global_each, global_each_limit = config.access.user_max_unique_ips_global_each,
@ -243,7 +270,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode)); let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode));
let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>)); let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>));
startup_tracker startup_tracker
.start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string())) .start_component(
COMPONENT_API_BOOTSTRAP,
Some("spawn API listener task".to_string()),
)
.await; .await;
if config.server.api.enabled { if config.server.api.enabled {
@ -326,7 +356,10 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.await; .await;
startup_tracker startup_tracker
.start_component(COMPONENT_NETWORK_PROBE, Some("probe network capabilities".to_string())) .start_component(
COMPONENT_NETWORK_PROBE,
Some("probe network capabilities".to_string()),
)
.await; .await;
let probe = run_probe( let probe = run_probe(
&config.network, &config.network,
@ -339,11 +372,8 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
probe.detected_ipv4.map(IpAddr::V4), probe.detected_ipv4.map(IpAddr::V4),
probe.detected_ipv6.map(IpAddr::V6), probe.detected_ipv6.map(IpAddr::V6),
)); ));
let decision = decide_network_capabilities( let decision =
&config.network, decide_network_capabilities(&config.network, &probe, config.general.middle_proxy_nat_ip);
&probe,
config.general.middle_proxy_nat_ip,
);
log_probe_result(&probe, &decision); log_probe_result(&probe, &decision);
startup_tracker startup_tracker
.complete_component( .complete_component(
@ -446,24 +476,16 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
// If ME failed to initialize, force direct-only mode. // If ME failed to initialize, force direct-only mode.
if me_pool.is_some() { if me_pool.is_some() {
startup_tracker startup_tracker.set_transport_mode("middle_proxy").await;
.set_transport_mode("middle_proxy") startup_tracker.set_degraded(false).await;
.await;
startup_tracker
.set_degraded(false)
.await;
info!("Transport: Middle-End Proxy - all DC-over-RPC"); info!("Transport: Middle-End Proxy - all DC-over-RPC");
} else { } else {
let _ = use_middle_proxy; let _ = use_middle_proxy;
use_middle_proxy = false; use_middle_proxy = false;
// Make runtime config reflect direct-only mode for handlers. // Make runtime config reflect direct-only mode for handlers.
config.general.use_middle_proxy = false; config.general.use_middle_proxy = false;
startup_tracker startup_tracker.set_transport_mode("direct").await;
.set_transport_mode("direct") startup_tracker.set_degraded(true).await;
.await;
startup_tracker
.set_degraded(true)
.await;
if me2dc_fallback { if me2dc_fallback {
startup_tracker startup_tracker
.set_me_status(StartupMeStatus::Failed, "fallback_to_direct") .set_me_status(StartupMeStatus::Failed, "fallback_to_direct")

View File

@ -4,21 +4,24 @@ use std::sync::Arc;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tracing::{debug, warn}; use tracing::{debug, warn};
use tracing_subscriber::reload;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use tracing_subscriber::reload;
use crate::config::{LogLevel, ProxyConfig};
use crate::config::hot_reload::spawn_config_watcher; use crate::config::hot_reload::spawn_config_watcher;
use crate::config::{LogLevel, ProxyConfig};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::metrics; use crate::metrics;
use crate::network::probe::NetworkProbe; use crate::network::probe::NetworkProbe;
use crate::startup::{COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker}; use crate::startup::{
COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
use super::helpers::write_beobachten_snapshot; use super::helpers::write_beobachten_snapshot;
@ -79,10 +82,8 @@ pub(crate) async fn spawn_runtime_tasks(
Some("spawn config hot-reload watcher".to_string()), Some("spawn config hot-reload watcher".to_string()),
) )
.await; .await;
let (config_rx, log_level_rx): ( let (config_rx, log_level_rx): (watch::Receiver<Arc<ProxyConfig>>, watch::Receiver<LogLevel>) =
watch::Receiver<Arc<ProxyConfig>>, spawn_config_watcher(
watch::Receiver<LogLevel>,
) = spawn_config_watcher(
config_path.to_path_buf(), config_path.to_path_buf(),
config.clone(), config.clone(),
detected_ip_v4, detected_ip_v4,
@ -114,7 +115,8 @@ pub(crate) async fn spawn_runtime_tasks(
break; break;
} }
let cfg = config_rx_policy.borrow_and_update().clone(); let cfg = config_rx_policy.borrow_and_update().clone();
stats_policy.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry)); stats_policy
.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry));
if let Some(pool) = &me_pool_for_policy { if let Some(pool) = &me_pool_for_policy {
pool.update_runtime_transport_policy( pool.update_runtime_transport_policy(
cfg.general.me_socks_kdf_policy, cfg.general.me_socks_kdf_policy,
@ -130,7 +132,11 @@ pub(crate) async fn spawn_runtime_tasks(
let ip_tracker_policy = ip_tracker.clone(); let ip_tracker_policy = ip_tracker.clone();
let mut config_rx_ip_limits = config_rx.clone(); let mut config_rx_ip_limits = config_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut prev_limits = config_rx_ip_limits.borrow().access.user_max_unique_ips.clone(); let mut prev_limits = config_rx_ip_limits
.borrow()
.access
.user_max_unique_ips
.clone();
let mut prev_global_each = config_rx_ip_limits let mut prev_global_each = config_rx_ip_limits
.borrow() .borrow()
.access .access
@ -183,7 +189,9 @@ pub(crate) async fn spawn_runtime_tasks(
let sleep_secs = cfg.general.beobachten_flush_secs.max(1); let sleep_secs = cfg.general.beobachten_flush_secs.max(1);
if cfg.general.beobachten { if cfg.general.beobachten {
let ttl = std::time::Duration::from_secs(cfg.general.beobachten_minutes.saturating_mul(60)); let ttl = std::time::Duration::from_secs(
cfg.general.beobachten_minutes.saturating_mul(60),
);
let path = cfg.general.beobachten_file.clone(); let path = cfg.general.beobachten_file.clone();
let snapshot = beobachten_writer.snapshot_text(ttl); let snapshot = beobachten_writer.snapshot_text(ttl);
if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await { if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await {
@ -227,7 +235,10 @@ pub(crate) async fn spawn_runtime_tasks(
let config_rx_clone_rot = config_rx.clone(); let config_rx_clone_rot = config_rx.clone();
let reinit_tx_rotation = reinit_tx.clone(); let reinit_tx_rotation = reinit_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
crate::transport::middle_proxy::me_rotation_task(config_rx_clone_rot, reinit_tx_rotation) crate::transport::middle_proxy::me_rotation_task(
config_rx_clone_rot,
reinit_tx_rotation,
)
.await; .await;
}); });
} }

View File

@ -16,7 +16,10 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti
let uptime_secs = process_started_at.elapsed().as_secs(); let uptime_secs = process_started_at.elapsed().as_secs();
info!("Uptime: {}", format_uptime(uptime_secs)); info!("Uptime: {}", format_uptime(uptime_secs));
if let Some(pool) = &me_pool { if let Some(pool) = &me_pool {
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all()) match tokio::time::timeout(
Duration::from_secs(2),
pool.shutdown_send_close_conn_all(),
)
.await .await
{ {
Ok(total) => { Ok(total) => {

View File

@ -7,11 +7,14 @@ mod crypto;
mod error; mod error;
mod ip_tracker; mod ip_tracker;
#[cfg(test)] #[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] #[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests; mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
mod maestro; mod maestro;
mod metrics; mod metrics;
mod network; mod network;

View File

@ -1,5 +1,5 @@
use std::convert::Infallible;
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -11,12 +11,12 @@ use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{info, warn, debug}; use tracing::{debug, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::stats::beobachten::BeobachtenStore;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::{ListenOptions, create_listener}; use crate::transport::{ListenOptions, create_listener};
pub async fn serve( pub async fn serve(
@ -62,7 +62,10 @@ pub async fn serve(
let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port));
match bind_metrics_listener(addr_v4, false) { match bind_metrics_listener(addr_v4, false) {
Ok(listener) => { Ok(listener) => {
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr_v4); info!(
"Metrics endpoint: http://{}/metrics and /beobachten",
addr_v4
);
listener_v4 = Some(listener); listener_v4 = Some(listener);
} }
Err(e) => { Err(e) => {
@ -73,7 +76,10 @@ pub async fn serve(
let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port));
match bind_metrics_listener(addr_v6, true) { match bind_metrics_listener(addr_v6, true) {
Ok(listener) => { Ok(listener) => {
info!("Metrics endpoint: http://[::]:{}/metrics and /beobachten", port); info!(
"Metrics endpoint: http://[::]:{}/metrics and /beobachten",
port
);
listener_v6 = Some(listener); listener_v6 = Some(listener);
} }
Err(e) => { Err(e) => {
@ -109,12 +115,7 @@ pub async fn serve(
.await; .await;
}); });
serve_listener( serve_listener(
listener4, listener4, stats, beobachten, ip_tracker, config_rx, whitelist,
stats,
beobachten,
ip_tracker,
config_rx,
whitelist,
) )
.await; .await;
} }
@ -231,7 +232,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge");
let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs());
let _ = writeln!(out, "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch"); let _ = writeln!(
out,
"# HELP telemt_telemetry_core_enabled Runtime core telemetry switch"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -239,7 +243,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
if core_enabled { 1 } else { 0 } if core_enabled { 1 } else { 0 }
); );
let _ = writeln!(out, "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch"); let _ = writeln!(
out,
"# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -247,7 +254,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
if user_enabled { 1 } else { 0 } if user_enabled { 1 } else { 0 }
); );
let _ = writeln!(out, "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag"); let _ = writeln!(
out,
"# HELP telemt_telemetry_me_level Runtime ME telemetry level flag"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -277,23 +287,40 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections"); let _ = writeln!(
out,
"# HELP telemt_connections_total Total accepted connections"
);
let _ = writeln!(out, "# TYPE telemt_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_connections_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_connections_total {}", "telemt_connections_total {}",
if core_enabled { stats.get_connects_all() } else { 0 } if core_enabled {
stats.get_connects_all()
} else {
0
}
); );
let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections"); let _ = writeln!(
out,
"# HELP telemt_connections_bad_total Bad/rejected connections"
);
let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter"); let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_connections_bad_total {}", "telemt_connections_bad_total {}",
if core_enabled { stats.get_connects_bad() } else { 0 } if core_enabled {
stats.get_connects_bad()
} else {
0
}
); );
let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts"); let _ = writeln!(
out,
"# HELP telemt_handshake_timeouts_total Handshake timeouts"
);
let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -372,7 +399,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle" "# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle"
); );
let _ = writeln!(out, "# TYPE telemt_upstream_connect_attempts_per_request counter"); let _ = writeln!(
out,
"# TYPE telemt_upstream_connect_attempts_per_request counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}", "telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}",
@ -414,7 +444,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration" "# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration"
); );
let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_success_total counter"); let _ = writeln!(
out,
"# TYPE telemt_upstream_connect_duration_success_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}", "telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}",
@ -456,7 +489,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration" "# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration"
); );
let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_fail_total counter"); let _ = writeln!(
out,
"# TYPE telemt_upstream_connect_duration_fail_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}", "telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}",
@ -494,7 +530,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent"); let _ = writeln!(
out,
"# HELP telemt_me_keepalive_sent_total ME keepalive frames sent"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -506,7 +545,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_keepalive_failed_total ME keepalive send failures"); let _ = writeln!(
out,
"# HELP telemt_me_keepalive_failed_total ME keepalive send failures"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -518,7 +560,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"); let _ = writeln!(
out,
"# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -530,7 +575,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"); let _ = writeln!(
out,
"# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -546,7 +594,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent" "# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent"
); );
let _ = writeln!(out, "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_rpc_proxy_req_signal_sent_total {}", "telemt_me_rpc_proxy_req_signal_sent_total {}",
@ -629,7 +680,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); let _ = writeln!(
out,
"# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"
);
let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -641,7 +695,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_reconnect_success_total ME reconnect successes"); let _ = writeln!(
out,
"# HELP telemt_me_reconnect_success_total ME reconnect successes"
);
let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -653,7 +710,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"); let _ = writeln!(
out,
"# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"
);
let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter"); let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -665,20 +725,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"); let _ = writeln!(
out,
"# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"
);
let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter"); let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter");
if me_allows_normal { if me_allows_normal {
for (error_code, count) in stats.get_me_handshake_error_code_counts() { for (error_code, count) in stats.get_me_handshake_error_code_counts() {
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}", "telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}",
error_code, error_code, count
count
); );
} }
} }
let _ = writeln!(out, "# HELP telemt_me_reader_eof_total ME reader EOF terminations"); let _ = writeln!(
out,
"# HELP telemt_me_reader_eof_total ME reader EOF terminations"
);
let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -780,7 +845,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches"); let _ = writeln!(
out,
"# HELP telemt_me_seq_mismatch_total ME sequence mismatches"
);
let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -792,7 +860,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"); let _ = writeln!(
out,
"# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"
);
let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -804,8 +875,14 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"); let _ = writeln!(
let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter"); out,
"# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"
);
let _ = writeln!(
out,
"# TYPE telemt_me_route_drop_channel_closed_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_route_drop_channel_closed_total {}", "telemt_me_route_drop_channel_closed_total {}",
@ -816,7 +893,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"); let _ = writeln!(
out,
"# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"
);
let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -973,7 +1053,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates" "# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates"
); );
let _ = writeln!(out, "# TYPE telemt_me_writer_pick_mode_switch_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_writer_pick_mode_switch_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_pick_mode_switch_total {}", "telemt_me_writer_pick_mode_switch_total {}",
@ -1023,7 +1106,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_kdf_drift_total ME KDF input drift detections"); let _ = writeln!(
out,
"# HELP telemt_me_kdf_drift_total ME KDF input drift detections"
);
let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter"); let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1069,7 +1155,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration" "# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration"
); );
let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_hardswap_pending_ttl_expired_total {}", "telemt_me_hardswap_pending_ttl_expired_total {}",
@ -1301,10 +1390,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap" "# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge");
out,
"# TYPE telemt_me_adaptive_floor_global_cap_raw gauge"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_adaptive_floor_global_cap_raw {}", "telemt_me_adaptive_floor_global_cap_raw {}",
@ -1487,7 +1573,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); let _ = writeln!(
out,
"# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"
);
let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1499,7 +1588,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_desync_total Total crypto-desync detections"); let _ = writeln!(
out,
"# HELP telemt_desync_total Total crypto-desync detections"
);
let _ = writeln!(out, "# TYPE telemt_desync_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1511,7 +1603,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted"); let _ = writeln!(
out,
"# HELP telemt_desync_full_logged_total Full forensic desync logs emitted"
);
let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1523,7 +1618,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_desync_suppressed_total Suppressed desync forensic events"); let _ = writeln!(
out,
"# HELP telemt_desync_suppressed_total Suppressed desync forensic events"
);
let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1535,7 +1633,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket"); let _ = writeln!(
out,
"# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket"
);
let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1574,7 +1675,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_pool_swap_total Successful ME pool swaps"); let _ = writeln!(
out,
"# HELP telemt_pool_swap_total Successful ME pool swaps"
);
let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1586,7 +1690,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_pool_drain_active Active draining ME writers"); let _ = writeln!(
out,
"# HELP telemt_pool_drain_active Active draining ME writers"
);
let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge"); let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1598,7 +1705,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_pool_force_close_total Forced close events for draining writers"); let _ = writeln!(
out,
"# HELP telemt_pool_force_close_total Forced close events for draining writers"
);
let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1610,7 +1720,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"); let _ = writeln!(
out,
"# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"
);
let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1622,7 +1735,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals"); let _ = writeln!(
out,
"# HELP telemt_me_writer_removed_total Total ME writer removals"
);
let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1638,7 +1754,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill" "# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill"
); );
let _ = writeln!(out, "# TYPE telemt_me_writer_removed_unexpected_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_writer_removed_unexpected_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_removed_unexpected_total {}", "telemt_me_writer_removed_unexpected_total {}",
@ -1649,7 +1768,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"); let _ = writeln!(
out,
"# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"
);
let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter"); let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1665,7 +1787,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup" "# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup"
); );
let _ = writeln!(out, "# TYPE telemt_me_refill_skipped_inflight_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_refill_skipped_inflight_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_refill_skipped_inflight_total {}", "telemt_me_refill_skipped_inflight_total {}",
@ -1676,7 +1801,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!(out, "# HELP telemt_me_refill_failed_total Immediate ME refill failures"); let _ = writeln!(
out,
"# HELP telemt_me_refill_failed_total Immediate ME refill failures"
);
let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1692,7 +1820,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint" "# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint"
); );
let _ = writeln!(out, "# TYPE telemt_me_writer_restored_same_endpoint_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_writer_restored_same_endpoint_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_restored_same_endpoint_total {}", "telemt_me_writer_restored_same_endpoint_total {}",
@ -1707,7 +1838,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint" "# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint"
); );
let _ = writeln!(out, "# TYPE telemt_me_writer_restored_fallback_total counter"); let _ = writeln!(
out,
"# TYPE telemt_me_writer_restored_fallback_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_restored_fallback_total {}", "telemt_me_writer_restored_fallback_total {}",
@ -1785,17 +1919,35 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
unresolved_writer_losses unresolved_writer_losses
); );
let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(
out,
"# HELP telemt_user_connections_total Per-user total connections"
);
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); let _ = writeln!(
out,
"# HELP telemt_user_connections_current Per-user active connections"
);
let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge"); let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge");
let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received"); let _ = writeln!(
out,
"# HELP telemt_user_octets_from_client Per-user bytes received"
);
let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter"); let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter");
let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent"); let _ = writeln!(
out,
"# HELP telemt_user_octets_to_client Per-user bytes sent"
);
let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter"); let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter");
let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received"); let _ = writeln!(
out,
"# HELP telemt_user_msgs_from_client Per-user messages received"
);
let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter"); let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter");
let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent"); let _ = writeln!(
out,
"# HELP telemt_user_msgs_to_client Per-user messages sent"
);
let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter"); let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1835,12 +1987,45 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
for entry in stats.iter_user_stats() { for entry in stats.iter_user_stats() {
let user = entry.key(); let user = entry.key();
let s = entry.value(); let s = entry.value();
let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed)); let _ = writeln!(
let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)); out,
let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed)); "telemt_user_connections_total{{user=\"{}\"}} {}",
let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed)); user,
let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed)); s.connects.load(std::sync::atomic::Ordering::Relaxed)
let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)); );
let _ = writeln!(
out,
"telemt_user_connections_current{{user=\"{}\"}} {}",
user,
s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_octets_from_client{{user=\"{}\"}} {}",
user,
s.octets_from_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_octets_to_client{{user=\"{}\"}} {}",
user,
s.octets_to_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_msgs_from_client{{user=\"{}\"}} {}",
user,
s.msgs_from_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_msgs_to_client{{user=\"{}\"}} {}",
user,
s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)
);
} }
let ip_stats = ip_tracker.get_stats().await; let ip_stats = ip_tracker.get_stats().await;
@ -1858,16 +2043,25 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
.get_recent_counts_for_users(&unique_users_vec) .get_recent_counts_for_users(&unique_users_vec)
.await; .await;
let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"); let _ = writeln!(
out,
"# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge");
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window" "# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window"
); );
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge");
let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)"); let _ = writeln!(
out,
"# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge");
let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"); let _ = writeln!(
out,
"# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge");
for user in unique_users { for user in unique_users {
@ -1878,29 +2072,34 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
.get(&user) .get(&user)
.copied() .copied()
.filter(|limit| *limit > 0) .filter(|limit| *limit > 0)
.or( .or((config.access.user_max_unique_ips_global_each > 0)
(config.access.user_max_unique_ips_global_each > 0) .then_some(config.access.user_max_unique_ips_global_each))
.then_some(config.access.user_max_unique_ips_global_each),
)
.unwrap_or(0); .unwrap_or(0);
let utilization = if limit > 0 { let utilization = if limit > 0 {
current as f64 / limit as f64 current as f64 / limit as f64
} else { } else {
0.0 0.0
}; };
let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current); let _ = writeln!(
out,
"telemt_user_unique_ips_current{{user=\"{}\"}} {}",
user, current
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}", "telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}",
user, user,
recent_counts.get(&user).copied().unwrap_or(0) recent_counts.get(&user).copied().unwrap_or(0)
); );
let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit); let _ = writeln!(
out,
"telemt_user_unique_ips_limit{{user=\"{}\"}} {}",
user, limit
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}", "telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}",
user, user, utilization
utilization
); );
} }
} }
@ -1911,8 +2110,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::net::IpAddr;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use std::net::IpAddr;
#[tokio::test] #[tokio::test]
async fn test_render_metrics_format() { async fn test_render_metrics_format() {
@ -1967,13 +2166,10 @@ mod tests {
assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_success_total 1"));
assert!(output.contains("telemt_upstream_connect_fail_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1"));
assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1")); assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1"));
assert!(output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1"));
assert!( assert!(
output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1") output
); .contains("telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1")
assert!(
output.contains(
"telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1"
)
); );
assert!( assert!(
output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1") output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1")
@ -2050,9 +2246,10 @@ mod tests {
assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter")); assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter"));
assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter"));
assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter"));
assert!(output.contains( assert!(
"# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" output
)); .contains("# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge")
);
assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge"));
@ -2069,14 +2266,17 @@ mod tests {
stats.increment_connects_all(); stats.increment_connects_all();
stats.increment_connects_all(); stats.increment_connects_all();
let req = Request::builder() let req = Request::builder().uri("/metrics").body(()).unwrap();
.uri("/metrics") let resp = handle(req, &stats, &beobachten, &tracker, &config)
.body(()) .await
.unwrap(); .unwrap();
let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let body = resp.into_body().collect().await.unwrap().to_bytes(); let body = resp.into_body().collect().await.unwrap().to_bytes();
assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3")); assert!(
std::str::from_utf8(body.as_ref())
.unwrap()
.contains("telemt_connections_total 3")
);
config.general.beobachten = true; config.general.beobachten = true;
config.general.beobachten_minutes = 10; config.general.beobachten_minutes = 10;
@ -2085,10 +2285,7 @@ mod tests {
"203.0.113.10".parse::<IpAddr>().unwrap(), "203.0.113.10".parse::<IpAddr>().unwrap(),
Duration::from_secs(600), Duration::from_secs(600),
); );
let req_beob = Request::builder() let req_beob = Request::builder().uri("/beobachten").body(()).unwrap();
.uri("/beobachten")
.body(())
.unwrap();
let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config)
.await .await
.unwrap(); .unwrap();
@ -2098,10 +2295,7 @@ mod tests {
assert!(beob_text.contains("[TLS-scanner]")); assert!(beob_text.contains("[TLS-scanner]"));
assert!(beob_text.contains("203.0.113.10-1")); assert!(beob_text.contains("203.0.113.10-1"));
let req404 = Request::builder() let req404 = Request::builder().uri("/other").body(()).unwrap();
.uri("/other")
.body(())
.unwrap();
let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) let resp404 = handle(req404, &stats, &beobachten, &tracker, &config)
.await .await
.unwrap(); .unwrap();

View File

@ -26,9 +26,7 @@ fn parse_ip_spec(ip_spec: &str) -> Result<IpAddr> {
} }
let ip = ip_spec.parse::<IpAddr>().map_err(|_| { let ip = ip_spec.parse::<IpAddr>().map_err(|_| {
ProxyError::Config(format!( ProxyError::Config(format!("network.dns_overrides IP is invalid: '{ip_spec}'"))
"network.dns_overrides IP is invalid: '{ip_spec}'"
))
})?; })?;
if matches!(ip, IpAddr::V6(_)) { if matches!(ip, IpAddr::V6(_)) {
return Err(ProxyError::Config(format!( return Err(ProxyError::Config(format!(
@ -103,9 +101,9 @@ pub fn validate_entries(entries: &[String]) -> Result<()> {
/// Replace runtime DNS overrides with a new validated snapshot. /// Replace runtime DNS overrides with a new validated snapshot.
pub fn install_entries(entries: &[String]) -> Result<()> { pub fn install_entries(entries: &[String]) -> Result<()> {
let parsed = parse_entries(entries)?; let parsed = parse_entries(entries)?;
let mut guard = overrides_store() let mut guard = overrides_store().write().map_err(|_| {
.write() ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string())
.map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?; })?;
*guard = parsed; *guard = parsed;
Ok(()) Ok(())
} }

View File

@ -1,4 +1,5 @@
#![allow(dead_code)] #![allow(dead_code)]
#![allow(clippy::items_after_test_module)]
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
@ -10,7 +11,9 @@ use tracing::{debug, info, warn};
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
use crate::error::Result; use crate::error::Result;
use crate::network::stun::{stun_probe_family_with_bind, DualStunResult, IpFamily, StunProbeResult}; use crate::network::stun::{
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -78,12 +81,7 @@ pub async fn run_probe(
warn!("STUN probe is enabled but network.stun_servers is empty"); warn!("STUN probe is enabled but network.stun_servers is empty");
DualStunResult::default() DualStunResult::default()
} else { } else {
probe_stun_servers_parallel( probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None)
&servers,
stun_nat_probe_concurrency.max(1),
None,
None,
)
.await .await
} }
} else if nat_probe { } else if nat_probe {
@ -99,7 +97,8 @@ pub async fn run_probe(
let UpstreamType::Direct { let UpstreamType::Direct {
interface, interface,
bind_addresses, bind_addresses,
} = &upstream.upstream_type else { } = &upstream.upstream_type
else {
continue; continue;
}; };
if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) { if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) {
@ -199,12 +198,11 @@ pub async fn run_probe(
if nat_probe if nat_probe
&& probe.reflected_ipv4.is_none() && probe.reflected_ipv4.is_none()
&& probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false)
&& let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await
{ {
if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await {
probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0));
info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback");
} }
}
probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) {
(Some(det), Some(reflected)) => det != reflected.ip(), (Some(det), Some(reflected)) => det != reflected.ip(),
@ -217,12 +215,20 @@ pub async fn run_probe(
probe.ipv4_usable = config.ipv4 probe.ipv4_usable = config.ipv4
&& probe.detected_ipv4.is_some() && probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false)); && (!probe.ipv4_is_bogon
|| probe
.reflected_ipv4
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()); let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
probe.ipv6_usable = ipv6_enabled probe.ipv6_usable = ipv6_enabled
&& probe.detected_ipv6.is_some() && probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false)); && (!probe.ipv6_is_bogon
|| probe
.reflected_ipv6
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
Ok(probe) Ok(probe)
} }
@ -280,8 +286,6 @@ async fn probe_stun_servers_parallel(
while next_idx < servers.len() && join_set.len() < concurrency { while next_idx < servers.len() && join_set.len() < concurrency {
let stun_addr = servers[next_idx].clone(); let stun_addr = servers[next_idx].clone();
next_idx += 1; next_idx += 1;
let bind_v4 = bind_v4;
let bind_v6 = bind_v6;
join_set.spawn(async move { join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, async { let res = timeout(STUN_BATCH_TIMEOUT, async {
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
@ -300,11 +304,15 @@ async fn probe_stun_servers_parallel(
match task { match task {
Ok((stun_addr, Ok(Ok(result)))) => { Ok((stun_addr, Ok(Ok(result)))) => {
if let Some(v4) = result.v4 { if let Some(v4) = result.v4 {
let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4)); let entry = best_v4_by_ip
.entry(v4.reflected_addr.ip())
.or_insert((0, v4));
entry.0 += 1; entry.0 += 1;
} }
if let Some(v6) = result.v6 { if let Some(v6) = result.v6 {
let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6)); let entry = best_v6_by_ip
.entry(v6.reflected_addr.ip())
.or_insert((0, v6));
entry.0 += 1; entry.0 += 1;
} }
if result.v4.is_some() || result.v6.is_some() { if result.v4.is_some() || result.v6.is_some() {
@ -324,17 +332,11 @@ async fn probe_stun_servers_parallel(
} }
let mut out = DualStunResult::default(); let mut out = DualStunResult::default();
if let Some((_, best)) = best_v4_by_ip if let Some((_, best)) = best_v4_by_ip.into_values().max_by_key(|(count, _)| *count) {
.into_values()
.max_by_key(|(count, _)| *count)
{
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v4 = Some(best); out.v4 = Some(best);
} }
if let Some((_, best)) = best_v6_by_ip if let Some((_, best)) = best_v6_by_ip.into_values().max_by_key(|(count, _)| *count) {
.into_values()
.max_by_key(|(count, _)| *count)
{
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v6 = Some(best); out.v6 = Some(best);
} }
@ -347,7 +349,8 @@ pub fn decide_network_capabilities(
middle_proxy_nat_ip: Option<IpAddr>, middle_proxy_nat_ip: Option<IpAddr>,
) -> NetworkDecision { ) -> NetworkDecision {
let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some();
let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some(); let ipv6_dc =
config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_))); let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_)));
let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_))); let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_)));
@ -534,10 +537,26 @@ pub fn is_bogon_v6(ip: Ipv6Addr) -> bool {
pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) { pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) {
info!( info!(
ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()), ipv4 = probe
ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()), .detected_ipv4
reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), .as_ref()
reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()), .map(|v| v.to_string())
.unwrap_or_else(|| "-".into()),
ipv6 = probe
.detected_ipv6
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "-".into()),
reflected_v4 = probe
.reflected_ipv4
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
reflected_v6 = probe
.reflected_ipv6
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
ipv4_bogon = probe.ipv4_is_bogon, ipv4_bogon = probe.ipv4_is_bogon,
ipv6_bogon = probe.ipv6_is_bogon, ipv6_bogon = probe.ipv6_is_bogon,
ipv4_me = decision.ipv4_me, ipv4_me = decision.ipv4_me,

View File

@ -4,8 +4,8 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::OnceLock; use std::sync::OnceLock;
use tokio::net::{lookup_host, UdpSocket}; use tokio::net::{UdpSocket, lookup_host};
use tokio::time::{timeout, Duration, sleep}; use tokio::time::{Duration, sleep, timeout};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@ -41,13 +41,13 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
stun_probe_family(stun_addr, IpFamily::V6), stun_probe_family(stun_addr, IpFamily::V6),
); );
Ok(DualStunResult { Ok(DualStunResult { v4: v4?, v6: v6? })
v4: v4?,
v6: v6?,
})
} }
pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> { pub async fn stun_probe_family(
stun_addr: &str,
family: IpFamily,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await stun_probe_family_with_bind(stun_addr, family, None).await
} }
@ -76,13 +76,18 @@ pub async fn stun_probe_family_with_bind(
if let Some(addr) = target_addr { if let Some(addr) = target_addr {
match socket.connect(addr).await { match socket.connect(addr).await {
Ok(()) => {} Ok(()) => {}
Err(e) if family == IpFamily::V6 && matches!( Err(e)
if family == IpFamily::V6
&& matches!(
e.kind(), e.kind(),
std::io::ErrorKind::NetworkUnreachable std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable | std::io::ErrorKind::HostUnreachable
| std::io::ErrorKind::Unsupported | std::io::ErrorKind::Unsupported
| std::io::ErrorKind::NetworkDown | std::io::ErrorKind::NetworkDown
) => return Ok(None), ) =>
{
return Ok(None);
}
Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))), Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))),
} }
} else { } else {
@ -205,7 +210,6 @@ pub async fn stun_probe_family_with_bind(
idx += (alen + 3) & !3; idx += (alen + 3) & !3;
} }
} }
Ok(None) Ok(None)
@ -233,7 +237,11 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
.await .await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?; .map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;
let target = addrs let target = addrs.find(|a| {
.find(|a| matches!((a.is_ipv4(), family), (true, IpFamily::V4) | (false, IpFamily::V6))); matches!(
(a.is_ipv4(), family),
(true, IpFamily::V4) | (false, IpFamily::V6)
)
});
Ok(target) Ok(target)
} }

View File

@ -36,32 +36,86 @@ pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); 1,
m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)],
m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]); );
m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); m.insert(
m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]); -1,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)],
);
m.insert(
2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
-2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]); m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]);
m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]); m.insert(
-4,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)],
);
m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); m.insert(
-5,
vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)],
);
m m
}); });
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); 1,
m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)],
m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]); );
m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); m.insert(
m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]); -1,
m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)],
m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]); );
m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); m.insert(
m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]); 2,
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
-2,
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
-4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m.insert(
-5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m m
}); });
@ -222,9 +276,7 @@ pub const SMALL_BUFFER_SIZE: usize = 8192;
// ============= Statistics ============= // ============= Statistics =============
/// Duration buckets for histogram metrics /// Duration buckets for histogram metrics
pub static DURATION_BUCKETS: &[f64] = &[ pub static DURATION_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0];
0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0,
];
// ============= Reserved Nonce Patterns ============= // ============= Reserved Nonce Patterns =============
@ -242,9 +294,7 @@ pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[
]; ];
/// Reserved continuation bytes (bytes 4-7) /// Reserved continuation bytes (bytes 4-7)
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[ pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[[0x00, 0x00, 0x00, 0x00]];
[0x00, 0x00, 0x00, 0x00],
];
// ============= RPC Constants (for Middle Proxy) ============= // ============= RPC Constants (for Middle Proxy) =============
@ -285,11 +335,10 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
// ============= Middle-End Proxy Servers =============
pub const ME_PROXY_PORT: u16 = 8888;
// ============= Middle-End Proxy Servers ============= pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
pub const ME_PROXY_PORT: u16 = 8888;
pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
vec![ vec![
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888),
@ -297,29 +346,29 @@ pub mod rpc_flags {
(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888),
] ]
}); });
// ============= RPC Constants (u32 native endian) ============= // ============= RPC Constants (u32 native endian) =============
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c // From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c
pub const RPC_NONCE_U32: u32 = 0x7acb87aa; pub const RPC_NONCE_U32: u32 = 0x7acb87aa;
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5;
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda;
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
// mtproto-common.h // mtproto-common.h
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
pub const RPC_PING_U32: u32 = 0x5730a2df; pub const RPC_PING_U32: u32 = 0x5730a2df;
pub const RPC_PONG_U32: u32 = 0x8430eaa7; pub const RPC_PONG_U32: u32 = 0x8430eaa7;
pub const RPC_CRYPTO_NONE_U32: u32 = 0; pub const RPC_CRYPTO_NONE_U32: u32 = 0;
pub const RPC_CRYPTO_AES_U32: u32 = 1; pub const RPC_CRYPTO_AES_U32: u32 = 1;
pub mod proxy_flags { pub mod proxy_flags {
pub const FLAG_HAS_AD_TAG: u32 = 1; pub const FLAG_HAS_AD_TAG: u32 = 1;
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG2: u32 = 0x8; pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
@ -329,20 +378,20 @@ pub mod rpc_flags {
pub const FLAG_INTERMEDIATE: u32 = 0x20000000; pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000; pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
pub mod rpc_crypto_flags { pub mod rpc_crypto_flags {
pub const USE_CRC32C: u32 = 0x800; pub const USE_CRC32C: u32 = 0x800;
} }
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
#[cfg(test)] #[cfg(test)]
#[path = "tests/tls_size_constants_security_tests.rs"] #[path = "tests/tls_size_constants_security_tests.rs"]
mod tls_size_constants_security_tests; mod tls_size_constants_security_tests;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -83,7 +83,7 @@ impl FrameMode {
/// Validate message length for MTProto /// Validate message length for MTProto
pub fn validate_message_length(len: usize) -> bool { pub fn validate_message_length(len: usize) -> bool {
use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER}; use super::constants::{MAX_MSG_LEN, MIN_MSG_LEN, PADDING_FILLER};
(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len()) (MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len())
} }

View File

@ -2,9 +2,9 @@
#![allow(dead_code)] #![allow(dead_code)]
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr};
use super::constants::*; use super::constants::*;
use crate::crypto::{AesCtr, sha256};
use zeroize::Zeroize;
/// Obfuscation parameters from handshake /// Obfuscation parameters from handshake
/// ///
@ -69,9 +69,8 @@ impl ObfuscationParams {
None => continue, None => continue,
}; };
let dc_idx = i16::from_le_bytes( let dc_idx =
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() i16::from_le_bytes(decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap());
);
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(enc_prekey);

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use std::time::Instant;
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use std::time::Instant;
/// Helper to create a byte vector of specific length. /// Helper to create a byte vector of specific length.
fn make_garbage(len: usize) -> Vec<u8> { fn make_garbage(len: usize) -> Vec<u8> {
@ -33,8 +33,7 @@ fn make_valid_tls_handshake_with_session_id(
let digest = make_digest(secret, &handshake, timestamp); let digest = make_digest(secret, &handshake, timestamp);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
handshake handshake
} }
@ -161,7 +160,10 @@ fn extract_sni_with_invalid_hostname_rejected() {
h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext); h.extend_from_slice(&ext);
assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected"); assert!(
extract_sni_from_client_hello(&h).is_none(),
"Invalid SNI hostname must be rejected"
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -323,7 +325,9 @@ fn extract_alpn_with_malformed_list_rejected() {
ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes());
ext.extend_from_slice(&alpn_payload); ext.extend_from_slice(&alpn_payload);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03]; let mut h = vec![
0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03,
];
h.extend_from_slice(&[0u8; 32]); h.extend_from_slice(&[0u8; 32]);
h.push(0); h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
@ -331,7 +335,10 @@ fn extract_alpn_with_malformed_list_rejected() {
h.extend_from_slice(&ext); h.extend_from_slice(&ext);
let res = extract_alpn_from_client_hello(&h); let res = extract_alpn_from_client_hello(&h);
assert!(res.is_empty(), "Malformed ALPN list must return empty or fail"); assert!(
res.is_empty(),
"Malformed ALPN list must return empty or fail"
);
} }
#[test] #[test]

View File

@ -84,7 +84,10 @@ fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u
#[test] #[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com")); assert_eq!(
extract_sni_from_client_hello(&valid).as_deref(),
Some("example.com")
);
assert_eq!( assert_eq!(
extract_alpn_from_client_hello(&valid), extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()] vec![b"h2".to_vec(), b"http/1.1".to_vec()]
@ -121,8 +124,14 @@ fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
continue; continue;
} }
assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI"); assert!(
assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN"); extract_sni_from_client_hello(input).is_none(),
"corpus item {idx} must fail closed for SNI"
);
assert!(
extract_alpn_from_client_hello(input).is_empty(),
"corpus item {idx} must fail closed for ALPN"
);
} }
} }
@ -163,7 +172,9 @@ fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
for _ in 0..32 { for _ in 0..32 {
let mut mutated = base.clone(); let mut mutated = base.clone();
for _ in 0..2 { for _ in 0..2 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
} }
@ -171,9 +182,13 @@ fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
} }
for (idx, handshake) in corpus.iter().enumerate() { for (idx, handshake) in corpus.iter().enumerate() {
let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); let result =
catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic"); assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed"); assert!(
result.unwrap().is_none(),
"corpus item {idx} must fail closed"
);
} }
} }

View File

@ -0,0 +1,37 @@
use super::*;
#[test]
fn extension_builder_fails_closed_on_u16_length_overflow() {
let builder = TlsExtensionBuilder {
extensions: vec![0u8; (u16::MAX as usize) + 1],
};
let built = builder.build();
assert!(
built.is_empty(),
"oversized extension blob must fail closed instead of truncating length field"
);
}
#[test]
fn server_hello_builder_fails_closed_on_session_id_len_overflow() {
let builder = ServerHelloBuilder {
random: [0u8; 32],
session_id: vec![0xAB; (u8::MAX as usize) + 1],
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0,
extensions: TlsExtensionBuilder::new(),
};
let message = builder.build_message();
let record = builder.build_record();
assert!(
message.is_empty(),
"session_id length overflow must fail closed in message builder"
);
assert!(
record.is_empty(),
"session_id length overflow must fail closed in record builder"
);
}

View File

@ -1,7 +1,9 @@
use super::*; use super::*;
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::tls_front::emulator::build_emulated_server_hello; use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource}; use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource,
};
use std::time::SystemTime; use std::time::SystemTime;
/// Build a TLS-handshake-like buffer that contains a valid HMAC digest /// Build a TLS-handshake-like buffer that contains a valid HMAC digest
@ -39,8 +41,7 @@ fn make_valid_tls_handshake_with_session_id(
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
handshake handshake
} }
@ -180,7 +181,10 @@ fn second_user_in_list_found_when_first_does_not_match() {
("user_b".to_string(), secret_b.to_vec()), ("user_b".to_string(), secret_b.to_vec()),
]; ];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
assert!(result.is_some(), "user_b must be found even though user_a comes first"); assert!(
result.is_some(),
"user_b must be found even though user_a comes first"
);
assert_eq!(result.unwrap().user, "user_b"); assert_eq!(result.unwrap().user, "user_b");
} }
@ -428,8 +432,7 @@ fn censor_probe_random_digests_all_rejected() {
let mut h = vec![0x42u8; min_len]; let mut h = vec![0x42u8; min_len];
h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let rand_digest = rng.bytes(TLS_DIGEST_LEN); let rand_digest = rng.bytes(TLS_DIGEST_LEN);
h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&rand_digest);
.copy_from_slice(&rand_digest);
assert!( assert!(
validate_tls_handshake(&h, &secrets, true).is_none(), validate_tls_handshake(&h, &secrets, true).is_none(),
"Random digest at attempt {attempt} must not match" "Random digest at attempt {attempt} must not match"
@ -553,8 +556,7 @@ fn system_time_before_unix_epoch_is_rejected_without_panic() {
fn system_time_far_future_overflowing_i64_returns_none() { fn system_time_far_future_overflowing_i64_returns_none() {
// i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`.
let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1;
if let Some(far_future) = if let Some(far_future) = UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs))
UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs))
{ {
assert!( assert!(
system_time_to_unix_secs(far_future).is_none(), system_time_to_unix_secs(far_future).is_none(),
@ -620,7 +622,10 @@ fn appended_trailing_byte_causes_rejection() {
let mut h = make_valid_tls_handshake(secret, 0); let mut h = make_valid_tls_handshake(secret, 0);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline"); assert!(
validate_tls_handshake(&h, &secrets, true).is_some(),
"baseline"
);
h.push(0x00); h.push(0x00);
assert!( assert!(
@ -647,8 +652,7 @@ fn zero_length_session_id_accepted() {
let computed = sha256_hmac(secret, &handshake); let computed = sha256_hmac(secret, &handshake);
// timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged.
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&computed);
.copy_from_slice(&computed);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
@ -773,10 +777,18 @@ fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
let cap_nonzero = let cap_nonzero = validate_tls_handshake_at_time_with_boot_cap(
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS); &h,
&secrets,
true,
0,
BOOT_TIME_COMPAT_MAX_SECS,
);
assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); assert!(
cap_zero.is_some(),
"ignore_time_skew=true must accept valid HMAC"
);
assert!( assert!(
cap_nonzero.is_some(), cap_nonzero.is_some(),
"ignore_time_skew path must not depend on boot-time cap" "ignore_time_skew path must not depend on boot-time cap"
@ -888,8 +900,8 @@ fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disa
let ts_i64 = now - offset; let ts_i64 = now - offset;
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix");
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) let accepted =
.is_some(); validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some();
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset);
assert_eq!( assert_eq!(
accepted, expected, accepted, expected,
@ -917,8 +929,8 @@ fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() {
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test");
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) let accepted =
.is_some(); validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some();
assert!( assert!(
!accepted, !accepted,
"offset {offset} must be rejected outside strict skew window" "offset {offset} must be rejected outside strict skew window"
@ -940,8 +952,8 @@ fn stress_boot_disabled_validation_matches_time_diff_oracle() {
let ts = s as u32; let ts = s as u32;
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) let accepted =
.is_some(); validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some();
let time_diff = now - i64::from(ts); let time_diff = now - i64::from(ts);
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff);
assert_eq!( assert_eq!(
@ -960,7 +972,10 @@ fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() {
let mut secrets = Vec::new(); let mut secrets = Vec::new();
for i in 0..512u32 { for i in 0..512u32 {
secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes())); secrets.push((
format!("noise-{i}"),
format!("noise-secret-{i}").into_bytes(),
));
} }
secrets.push(("target-user".to_string(), target_secret.to_vec())); secrets.push(("target-user".to_string(), target_secret.to_vec()));
@ -1018,7 +1033,10 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() {
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&h, &secrets, true); let result = validate_tls_handshake(&h, &secrets, true);
assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true"); assert!(
result.is_some(),
"u32::MAX timestamp must be accepted with ignore_time_skew=true"
);
assert_eq!( assert_eq!(
result.unwrap().timestamp, result.unwrap().timestamp,
u32::MAX, u32::MAX,
@ -1159,7 +1177,8 @@ fn first_matching_user_wins_over_later_duplicate_secret() {
let result = validate_tls_handshake(&h, &secrets, true); let result = validate_tls_handshake(&h, &secrets, true);
assert!(result.is_some()); assert!(result.is_some());
assert_eq!( assert_eq!(
result.unwrap().user, "winner", result.unwrap().user,
"winner",
"first matching user must be returned even when a later entry also matches" "first matching user must be returned even when a later entry also matches"
); );
} }
@ -1425,7 +1444,8 @@ fn test_build_server_hello_structure() {
assert!(response.len() > ccs_start + 6); assert!(response.len() > ccs_start + 6);
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; let ccs_len =
5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + ccs_len; let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5); assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION); assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
@ -1729,7 +1749,10 @@ fn empty_secret_hmac_is_supported() {
let handshake = make_valid_tls_handshake(secret, 0); let handshake = make_valid_tls_handshake(secret, 0);
let secrets = vec![("empty".to_string(), secret.to_vec())]; let secrets = vec![("empty".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct"); assert!(
result.is_some(),
"Empty HMAC key must not panic and must validate when correct"
);
} }
#[test] #[test]
@ -1802,7 +1825,10 @@ fn server_hello_application_data_payload_varies_across_runs() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec();
assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler"); assert!(
payload.iter().any(|&b| b != 0),
"Payload must not be all-zero deterministic filler"
);
unique_payloads.insert(payload); unique_payloads.insert(payload);
} }
@ -1846,7 +1872,13 @@ fn large_replay_window_does_not_expand_time_skew_acceptance() {
#[test] #[test]
fn parse_tls_record_header_accepts_tls_version_constant() { fn parse_tls_record_header_accepts_tls_version_constant() {
let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A]; let header = [
TLS_RECORD_HANDSHAKE,
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x2A,
];
let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted"); let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE); assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
assert_eq!(parsed.1, 42); assert_eq!(parsed.1, 42);
@ -1868,7 +1900,10 @@ fn server_hello_clamps_fake_cert_len_lower_bound() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes"); assert_eq!(
app_len, 64,
"fake cert payload must be clamped to minimum 64 bytes"
);
} }
#[test] #[test]
@ -1887,7 +1922,10 @@ fn server_hello_clamps_fake_cert_len_upper_bound() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!(app_len, MAX_TLS_CIPHERTEXT_SIZE, "fake cert payload must be clamped to TLS record max bound"); assert_eq!(
app_len, MAX_TLS_CIPHERTEXT_SIZE,
"fake cert payload must be clamped to TLS record max bound"
);
} }
#[test] #[test]
@ -1898,7 +1936,15 @@ fn server_hello_new_session_ticket_count_matches_configuration() {
let rng = crate::crypto::SecureRandom::new(); let rng = crate::crypto::SecureRandom::new();
let tickets: u8 = 3; let tickets: u8 = 3;
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets); let response = build_server_hello(
secret,
&client_digest,
&session_id,
1024,
&rng,
None,
tickets,
);
let mut pos = 0usize; let mut pos = 0usize;
let mut app_records = 0usize; let mut app_records = 0usize;
@ -1906,7 +1952,10 @@ fn server_hello_new_session_ticket_count_matches_configuration() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds"); assert!(
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
} }
@ -1927,7 +1976,15 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
let session_id = vec![0x54; 32]; let session_id = vec![0x54; 32];
let rng = crate::crypto::SecureRandom::new(); let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX); let response = build_server_hello(
secret,
&client_digest,
&session_id,
1024,
&rng,
None,
u8::MAX,
);
let mut pos = 0usize; let mut pos = 0usize;
let mut app_records = 0usize; let mut app_records = 0usize;
@ -1935,7 +1992,10 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds"); assert!(
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
} }
@ -1943,8 +2003,7 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
} }
assert_eq!( assert_eq!(
app_records, app_records, 5,
5,
"response must cap ticket-like tail records to four plus one main application record" "response must cap ticket-like tail records to four plus one main application record"
); );
} }
@ -1972,10 +2031,14 @@ fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() {
std::thread::sleep(std::time::Duration::from_millis(70)); std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) let validation_after_expiry =
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake must still cryptographically validate after cache expiry"); .expect("boot-time handshake must still cryptographically validate after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake"); assert_eq!(
digest_half, digest_half_after_expiry,
"replay key must be stable for same handshake"
);
assert!( assert!(
checker.check_and_add_tls_digest(digest_half_after_expiry), checker.check_and_add_tls_digest(digest_half_after_expiry),
@ -2006,7 +2069,8 @@ fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry()
std::thread::sleep(std::time::Duration::from_millis(70)); std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) let validation_after_expiry =
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
.expect("boot-time handshake still validates cryptographically after cache expiry"); .expect("boot-time handshake still validates cryptographically after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
@ -2067,11 +2131,14 @@ fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap
let ts = (s as u32) % 8; let ts = (s as u32) % 8;
let handshake = make_valid_tls_handshake(secret, ts); let handshake = make_valid_tls_handshake(secret, ts);
let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) let accepted =
.is_some(); validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2).is_some();
if ts < 2 { if ts < 2 {
assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap"); assert!(
accepted,
"timestamp {ts} must remain boot-time compatible under 2s cap"
);
} else { } else {
assert!( assert!(
!accepted, !accepted,
@ -2107,7 +2174,9 @@ fn server_hello_application_data_contains_alpn_marker_when_selected() {
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!( assert!(
app_payload.windows(expected.len()).any(|window| window == expected), app_payload
.windows(expected.len())
.any(|window| window == expected),
"first application payload must carry ALPN marker for selected protocol" "first application payload must carry ALPN marker for selected protocol"
); );
} }
@ -2137,7 +2206,10 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!(next <= response.len(), "TLS record must stay inside response bounds"); assert!(
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
if first_app_payload.is_none() { if first_app_payload.is_none() {
@ -2146,7 +2218,9 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
} }
pos = next; pos = next;
} }
let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x']; let marker = [
0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x',
];
assert_eq!( assert_eq!(
app_records, 5, app_records, 5,
@ -2310,13 +2384,13 @@ fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() {
&& header[1] == 0x03 && header[1] == 0x03
&& (header[2] == 0x01 || header[2] == 0x03); && (header[2] == 0x01 || header[2] == 0x03);
assert_eq!( assert_eq!(
classified, classified, expected_classified,
expected_classified,
"classifier policy mismatch for header {header:02x?}" "classifier policy mismatch for header {header:02x?}"
); );
let parsed = parse_tls_record_header(&header); let parsed = parse_tls_record_header(&header);
let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); let expected_parsed =
header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]);
assert_eq!( assert_eq!(
parsed.is_some(), parsed.is_some(),
expected_parsed, expected_parsed,

View File

@ -1,8 +1,4 @@
use super::{ use super::{MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
MAX_TLS_CIPHERTEXT_SIZE,
MAX_TLS_PLAINTEXT_SIZE,
MIN_TLS_CLIENT_HELLO_SIZE,
};
#[test] #[test]
fn tls_size_constants_match_rfc_8446() { fn tls_size_constants_match_rfc_8446() {

View File

@ -5,11 +5,66 @@
//! actually carries MTProto authentication data. //! actually carries MTProto authentication data.
#![allow(dead_code)] #![allow(dead_code)]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
use crate::crypto::{sha256_hmac, SecureRandom}; use super::constants::*;
use crate::crypto::{SecureRandom, sha256_hmac};
#[cfg(test)] #[cfg(test)]
use crate::error::ProxyError; use crate::error::ProxyError;
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
@ -69,7 +124,6 @@ pub struct TlsValidation {
/// Client digest for response generation /// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN], pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest /// Timestamp extracted from digest
pub timestamp: u32, pub timestamp: u32,
} }
@ -91,7 +145,8 @@ impl TlsExtensionBuilder {
/// Add Key Share extension with X25519 key /// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self { fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
// Extension type: key_share (0x0033) // Extension type: key_share (0x0033)
self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes()); self.extensions
.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes // Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length // Extension data length
@ -99,7 +154,8 @@ impl TlsExtensionBuilder {
self.extensions.extend_from_slice(&entry_len.to_be_bytes()); self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519 // Named curve: x25519
self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes()); self.extensions
.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length // Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes()); self.extensions.extend_from_slice(&(32u16).to_be_bytes());
@ -113,7 +169,8 @@ impl TlsExtensionBuilder {
/// Add Supported Versions extension /// Add Supported Versions extension
fn add_supported_versions(&mut self, version: u16) -> &mut Self { fn add_supported_versions(&mut self, version: u16) -> &mut Self {
// Extension type: supported_versions (0x002b) // Extension type: supported_versions (0x002b)
self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes()); self.extensions
.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
// Extension data: length (2) + version (2) // Extension data: length (2) + version (2)
self.extensions.extend_from_slice(&(2u16).to_be_bytes()); self.extensions.extend_from_slice(&(2u16).to_be_bytes());
@ -125,12 +182,13 @@ impl TlsExtensionBuilder {
} }
/// Build final extensions with length prefix /// Build final extensions with length prefix
fn build(self) -> Vec<u8> { fn build(self) -> Vec<u8> {
let Ok(len) = u16::try_from(self.extensions.len()) else {
return Vec::new();
};
let mut result = Vec::with_capacity(2 + self.extensions.len()); let mut result = Vec::with_capacity(2 + self.extensions.len());
// Extensions length (2 bytes) // Extensions length (2 bytes)
let len = self.extensions.len() as u16;
result.extend_from_slice(&len.to_be_bytes()); result.extend_from_slice(&len.to_be_bytes());
// Extensions data // Extensions data
@ -140,7 +198,6 @@ impl TlsExtensionBuilder {
} }
/// Get current extensions without length prefix (for calculation) /// Get current extensions without length prefix (for calculation)
fn as_bytes(&self) -> &[u8] { fn as_bytes(&self) -> &[u8] {
&self.extensions &self.extensions
} }
@ -186,8 +243,13 @@ impl ServerHelloBuilder {
/// Build ServerHello message (without record header) /// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> { fn build_message(&self) -> Vec<u8> {
let Ok(session_id_len) = u8::try_from(self.session_id.len()) else {
return Vec::new();
};
let extensions = self.extensions.extensions.clone(); let extensions = self.extensions.extensions.clone();
let extensions_len = extensions.len() as u16; let Ok(extensions_len) = u16::try_from(extensions.len()) else {
return Vec::new();
};
// Calculate total length // Calculate total length
let body_len = 2 + // version let body_len = 2 + // version
@ -196,6 +258,9 @@ impl ServerHelloBuilder {
2 + // cipher suite 2 + // cipher suite
1 + // compression 1 + // compression
2 + extensions.len(); // extensions length + data 2 + extensions.len(); // extensions length + data
if body_len > 0x00ff_ffff {
return Vec::new();
}
let mut message = Vec::with_capacity(4 + body_len); let mut message = Vec::with_capacity(4 + body_len);
@ -203,7 +268,10 @@ impl ServerHelloBuilder {
message.push(0x02); // ServerHello message type message.push(0x02); // ServerHello message type
// 3-byte length // 3-byte length
let len_bytes = (body_len as u32).to_be_bytes(); let Ok(body_len_u32) = u32::try_from(body_len) else {
return Vec::new();
};
let len_bytes = body_len_u32.to_be_bytes();
message.extend_from_slice(&len_bytes[1..4]); message.extend_from_slice(&len_bytes[1..4]);
// Server version (TLS 1.2 in header, actual version in extension) // Server version (TLS 1.2 in header, actual version in extension)
@ -213,7 +281,7 @@ impl ServerHelloBuilder {
message.extend_from_slice(&self.random); message.extend_from_slice(&self.random);
// Session ID // Session ID
message.push(self.session_id.len() as u8); message.push(session_id_len);
message.extend_from_slice(&self.session_id); message.extend_from_slice(&self.session_id);
// Cipher suite // Cipher suite
@ -234,13 +302,19 @@ impl ServerHelloBuilder {
/// Build complete ServerHello TLS record /// Build complete ServerHello TLS record
fn build_record(&self) -> Vec<u8> { fn build_record(&self) -> Vec<u8> {
let message = self.build_message(); let message = self.build_message();
if message.is_empty() {
return Vec::new();
}
let Ok(message_len) = u16::try_from(message.len()) else {
return Vec::new();
};
let mut record = Vec::with_capacity(5 + message.len()); let mut record = Vec::with_capacity(5 + message.len());
// TLS record header // TLS record header
record.push(TLS_RECORD_HANDSHAKE); record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(message.len() as u16).to_be_bytes()); record.extend_from_slice(&message_len.to_be_bytes());
// Message // Message
record.extend_from_slice(&message); record.extend_from_slice(&message);
@ -256,7 +330,6 @@ impl ServerHelloBuilder {
/// Returns validation result if a matching user is found. /// Returns validation result if a matching user is found.
/// The result **must** be used — ignoring it silently bypasses authentication. /// The result **must** be used — ignoring it silently bypasses authentication.
#[must_use] #[must_use]
pub fn validate_tls_handshake( pub fn validate_tls_handshake(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
@ -320,7 +393,6 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
i64::try_from(d.as_secs()).ok() i64::try_from(d.as_secs()).ok()
} }
fn validate_tls_handshake_at_time( fn validate_tls_handshake_at_time(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
@ -463,15 +535,20 @@ pub fn build_server_hello(
// Build Change Cipher Spec record // Build Change Cipher Spec record
let change_cipher_spec = [ let change_cipher_spec = [
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[1], TLS_VERSION[0],
0x00, 0x01, // length = 1 TLS_VERSION[1],
0x00,
0x01, // length = 1
0x01, // CCS byte 0x01, // CCS byte
]; ];
// Build first encrypted flight mimic as opaque ApplicationData bytes. // Build first encrypted flight mimic as opaque ApplicationData bytes.
// Embed a compact EncryptedExtensions-like ALPN block when selected. // Embed a compact EncryptedExtensions-like ALPN block when selected.
let mut fake_cert = Vec::with_capacity(fake_cert_len); let mut fake_cert = Vec::with_capacity(fake_cert_len);
if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) { if let Some(proto) = alpn
.as_ref()
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
{
let proto_list_len = 1usize + proto.len(); let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len; let ext_data_len = 2usize + proto_list_len;
let marker_len = 4usize + ext_data_len; let marker_len = 4usize + ext_data_len;
@ -515,7 +592,10 @@ pub fn build_server_hello(
// Combine all records // Combine all records
let mut response = Vec::with_capacity( let mut response = Vec::with_capacity(
server_hello.len() + change_cipher_spec.len() + app_data_record.len() + tickets.iter().map(|r| r.len()).sum::<usize>() server_hello.len()
+ change_cipher_spec.len()
+ app_data_record.len()
+ tickets.iter().map(|r| r.len()).sum::<usize>(),
); );
response.extend_from_slice(&server_hello); response.extend_from_slice(&server_hello);
response.extend_from_slice(&change_cipher_spec); response.extend_from_slice(&change_cipher_spec);
@ -532,8 +612,7 @@ pub fn build_server_hello(
// Insert computed digest into ServerHello // Insert computed digest into ServerHello
// Position: record header (5) + message type (1) + length (3) + version (2) = 11 // Position: record header (5) + message type (1) + length (3) + version (2) = 11
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&response_digest);
.copy_from_slice(&response_digest);
response response
} }
@ -611,19 +690,20 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
let sn_end = std::cmp::min(sn_pos + list_len, pos + elen); let sn_end = std::cmp::min(sn_pos + list_len, pos + elen);
while sn_pos + 3 <= sn_end { while sn_pos + 3 <= sn_end {
let name_type = handshake[sn_pos]; let name_type = handshake[sn_pos];
let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize; let name_len =
u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
sn_pos += 3; sn_pos += 3;
if sn_pos + name_len > sn_end { if sn_pos + name_len > sn_end {
break; break;
} }
if name_type == 0 && name_len > 0 if name_type == 0
&& name_len > 0
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
&& is_valid_sni_hostname(host)
{ {
if is_valid_sni_hostname(host) {
extracted_sni = Some(host.to_string()); extracted_sni = Some(host.to_string());
break; break;
} }
}
sn_pos += name_len; sn_pos += name_len;
} }
} }
@ -679,35 +759,49 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
} }
pos += 4; // type + len pos += 4; // type + len
pos += 2 + 32; // version + random pos += 2 + 32; // version + random
if pos >= handshake.len() { return Vec::new(); } if pos >= handshake.len() {
return Vec::new();
}
let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize; let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + session_id_len; pos += 1 + session_id_len;
if pos + 2 > handshake.len() { return Vec::new(); } if pos + 2 > handshake.len() {
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; return Vec::new();
}
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2 + cipher_len; pos += 2 + cipher_len;
if pos >= handshake.len() { return Vec::new(); } if pos >= handshake.len() {
return Vec::new();
}
let comp_len = *handshake.get(pos).unwrap_or(&0) as usize; let comp_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + comp_len; pos += 1 + comp_len;
if pos + 2 > handshake.len() { return Vec::new(); } if pos + 2 > handshake.len() {
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; return Vec::new();
}
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2; pos += 2;
let ext_end = pos + ext_len; let ext_end = pos + ext_len;
if ext_end > handshake.len() { return Vec::new(); } if ext_end > handshake.len() {
return Vec::new();
}
let mut out = Vec::new(); let mut out = Vec::new();
while pos + 4 <= ext_end { while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([handshake[pos], handshake[pos+1]]); let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
let elen = u16::from_be_bytes([handshake[pos+2], handshake[pos+3]]) as usize; let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
pos += 4; pos += 4;
if pos + elen > ext_end { break; } if pos + elen > ext_end {
break;
}
if etype == extension_type::ALPN && elen >= 3 { if etype == extension_type::ALPN && elen >= 3 {
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize; let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
let mut lp = pos + 2; let mut lp = pos + 2;
let list_end = (pos + 2).saturating_add(list_len).min(pos + elen); let list_end = (pos + 2).saturating_add(list_len).min(pos + elen);
while lp < list_end { while lp < list_end {
let plen = handshake[lp] as usize; let plen = handshake[lp] as usize;
lp += 1; lp += 1;
if lp + plen > list_end { break; } if lp + plen > list_end {
out.push(handshake[lp..lp+plen].to_vec()); break;
}
out.push(handshake[lp..lp + plen].to_vec());
lp += plen; lp += plen;
} }
break; break;
@ -717,7 +811,6 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
out out
} }
/// Check if bytes look like a TLS ClientHello /// Check if bytes look like a TLS ClientHello
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
if first_bytes.len() < 3 { if first_bytes.len() < 3 {
@ -731,7 +824,6 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
} }
/// Parse TLS record header, returns (record_type, length) /// Parse TLS record header, returns (record_type, length)
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0]; let record_type = header[0];
let version = [header[1], header[2]]; let version = [header[1], header[2]];
@ -776,25 +868,28 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
// Check record length // Check record length
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize; let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
if data.len() < 5 + record_len { if data.len() < 5 + record_len {
return Err(ProxyError::InvalidHandshake( return Err(ProxyError::InvalidHandshake(format!(
format!("ServerHello record truncated: expected {}, got {}", "ServerHello record truncated: expected {}, got {}",
5 + record_len, data.len()) 5 + record_len,
)); data.len()
)));
} }
// Check message type // Check message type
if data[5] != 0x02 { if data[5] != 0x02 {
return Err(ProxyError::InvalidHandshake( return Err(ProxyError::InvalidHandshake(format!(
format!("Expected ServerHello (0x02), got 0x{:02x}", data[5]) "Expected ServerHello (0x02), got 0x{:02x}",
)); data[5]
)));
} }
// Parse message length // Parse message length
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize; let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
if msg_len + 4 != record_len { if msg_len + 4 != record_len {
return Err(ProxyError::InvalidHandshake( return Err(ProxyError::InvalidHandshake(format!(
format!("Message length mismatch: {} + 4 != {}", msg_len, record_len) "Message length mismatch: {} + 4 != {}",
)); msg_len, record_len
)));
} }
Ok(()) Ok(())
@ -806,7 +901,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
/// Using `static_assertions` ensures these can never silently break across /// Using `static_assertions` ensures these can never silently break across
/// refactors without a compile error. /// refactors without a compile error.
mod compile_time_security_checks { mod compile_time_security_checks {
use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN}; use super::{TLS_DIGEST_HALF_LEN, TLS_DIGEST_LEN};
use static_assertions::const_assert; use static_assertions::const_assert;
// The digest must be exactly one SHA-256 output. // The digest must be exactly one SHA-256 output.
@ -834,3 +929,7 @@ mod adversarial_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/tls_fuzz_security_tests.rs"] #[path = "tests/tls_fuzz_security_tests.rs"]
mod fuzz_security_tests; mod fuzz_security_tests;
#[cfg(test)]
#[path = "tests/tls_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;

View File

@ -1,3 +1,8 @@
#![allow(dead_code)]
// Adaptive buffer policy is staged and retained for deterministic rollout.
// Keep definitions compiled for compatibility and security test scaffolding.
use dashmap::DashMap; use dashmap::DashMap;
use std::cmp::max; use std::cmp::max;
use std::sync::OnceLock; use std::sync::OnceLock;
@ -170,7 +175,8 @@ impl SessionAdaptiveController {
return self.promote(TierTransitionReason::SoftConfirmed, 0); return self.promote(TierTransitionReason::SoftConfirmed, 0);
} }
let demote_candidate = self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now; let demote_candidate =
self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now;
if demote_candidate { if demote_candidate {
self.quiet_ticks = self.quiet_ticks.saturating_add(1); self.quiet_ticks = self.quiet_ticks.saturating_add(1);
if self.quiet_ticks >= QUIET_DEMOTE_TICKS { if self.quiet_ticks >= QUIET_DEMOTE_TICKS {
@ -253,10 +259,7 @@ pub fn record_user_tier(user: &str, tier: AdaptiveTier) {
}; };
return; return;
} }
profiles().insert( profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now });
user.to_string(),
UserAdaptiveProfile { tier, seen_at: now },
);
} }
pub fn direct_copy_buffers_for_tier( pub fn direct_copy_buffers_for_tier(
@ -339,10 +342,7 @@ mod tests {
sample( sample(
300_000, // ~9.6 Mbps 300_000, // ~9.6 Mbps
320_000, // incoming > outgoing to confirm tier2 320_000, // incoming > outgoing to confirm tier2
250_000, 250_000, 10, 0, 0,
10,
0,
0,
), ),
tick_secs, tick_secs,
); );
@ -358,10 +358,7 @@ mod tests {
fn test_hard_promotion_on_pending_pressure() { fn test_hard_promotion_on_pending_pressure() {
let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base); let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base);
let transition = ctrl let transition = ctrl
.observe( .observe(sample(10_000, 20_000, 10_000, 4, 1, 3), 0.25)
sample(10_000, 20_000, 10_000, 4, 1, 3),
0.25,
)
.expect("expected hard promotion"); .expect("expected hard promotion");
assert_eq!(transition.reason, TierTransitionReason::HardPressure); assert_eq!(transition.reason, TierTransitionReason::HardPressure);
assert_eq!(transition.to, AdaptiveTier::Tier1); assert_eq!(transition.to, AdaptiveTier::Tier1);

View File

@ -1,5 +1,7 @@
//! Client Handler //! Client Handler
use ipnetwork::IpNetwork;
use rand::RngExt;
use std::future::Future; use std::future::Future;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::Pin; use std::pin::Pin;
@ -7,8 +9,6 @@ use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration; use std::time::Duration;
use ipnetwork::IpNetwork;
use rand::RngExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
@ -75,10 +75,10 @@ use crate::protocol::tls;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
use crate::transport::socket::normalize_ip;
use crate::tls_front::TlsFrontCache; use crate::tls_front::TlsFrontCache;
use crate::transport::middle_proxy::MePool;
use crate::transport::socket::normalize_ip;
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
@ -116,11 +116,23 @@ fn beobachten_ttl(config: &ProxyConfig) -> Duration {
} }
fn wrap_tls_application_record(payload: &[u8]) -> Vec<u8> { fn wrap_tls_application_record(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len()); let chunks = payload.len().div_ceil(u16::MAX as usize).max(1);
let mut record = Vec::with_capacity(payload.len() + 5 * chunks);
if payload.is_empty() {
record.push(TLS_RECORD_APPLICATION); record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); record.extend_from_slice(&0u16.to_be_bytes());
record.extend_from_slice(payload); return record;
}
for chunk in payload.chunks(u16::MAX as usize) {
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(chunk.len() as u16).to_be_bytes());
record.extend_from_slice(chunk);
}
record record
} }
@ -128,7 +140,10 @@ fn tls_clienthello_len_in_bounds(tls_len: usize) -> bool {
(MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len)
} }
async fn read_with_progress<R: AsyncRead + Unpin>(reader: &mut R, mut buf: &mut [u8]) -> std::io::Result<usize> { async fn read_with_progress<R: AsyncRead + Unpin>(
reader: &mut R,
mut buf: &mut [u8],
) -> std::io::Result<usize> {
let mut total = 0usize; let mut total = 0usize;
while !buf.is_empty() { while !buf.is_empty() {
match reader.read(buf).await { match reader.read(buf).await {
@ -271,10 +286,14 @@ where
let mut local_addr = synthetic_local_addr(config.server.port); let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled { if proxy_protocol_enabled {
let proxy_header_timeout = Duration::from_millis( let proxy_header_timeout =
config.server.proxy_protocol_header_timeout_ms.max(1), Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1));
); match timeout(
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { proxy_header_timeout,
parse_proxy_protocol(&mut stream, peer),
)
.await
{
Ok(Ok(info)) => { Ok(Ok(info)) => {
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
{ {
@ -674,9 +693,8 @@ impl RunningClientHandler {
let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
if self.proxy_protocol_enabled { if self.proxy_protocol_enabled {
let proxy_header_timeout = Duration::from_millis( let proxy_header_timeout =
self.config.server.proxy_protocol_header_timeout_ms.max(1), Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1));
);
match timeout( match timeout(
proxy_header_timeout, proxy_header_timeout,
parse_proxy_protocol(&mut self.stream, self.peer), parse_proxy_protocol(&mut self.stream, self.peer),
@ -761,7 +779,11 @@ impl RunningClientHandler {
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> { async fn handle_tls_client(
mut self,
first_bytes: [u8; 5],
local_addr: SocketAddr,
) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -895,7 +917,8 @@ impl RunningClientHandler {
} else { } else {
wrap_tls_application_record(&pending_plaintext) wrap_tls_application_record(&pending_plaintext)
}; };
let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); let reader =
tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader);
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!( debug!(
peer = %peer, peer = %peer,
@ -933,7 +956,11 @@ impl RunningClientHandler {
))) )))
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> { async fn handle_direct_client(
mut self,
first_bytes: [u8; 5],
local_addr: SocketAddr,
) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
@ -1035,8 +1062,7 @@ impl RunningClientHandler {
{ {
let user = success.user.clone(); let user = success.user.clone();
let user_limit_reservation = let user_limit_reservation = match Self::acquire_user_connection_reservation_static(
match Self::acquire_user_connection_reservation_static(
&user, &user,
&config, &config,
stats.clone(), stats.clone(),
@ -1134,7 +1160,11 @@ impl RunningClientHandler {
}); });
} }
let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); let limit = config
.access
.user_max_tcp_conns
.get(user)
.map(|v| *v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) { if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded { return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(), user: user.to_string(),
@ -1294,3 +1324,7 @@ mod masking_probe_evasion_blackhat_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests; mod beobachten_ttl_bounds_security_tests;
#[cfg(test)]
#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"]
mod tls_record_wrap_hardening_security_tests;

View File

@ -1,10 +1,10 @@
use std::collections::HashSet;
use std::ffi::OsString; use std::ffi::OsString;
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::Write; use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{Component, Path, PathBuf}; use std::path::{Component, Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
@ -24,13 +24,13 @@ use crate::proxy::route_mode::{
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
#[cfg(unix)]
use nix::fcntl::{Flock, FlockArg, OFlag, openat};
#[cfg(unix)]
use nix::sys::stat::Mode;
#[cfg(unix)] #[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt; use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd};
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new(); static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
@ -160,7 +160,9 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
} }
} }
fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result<std::fs::File> { fn open_unknown_dc_log_append_anchored(
path: &SanitizedUnknownDcLogPath,
) -> std::io::Result<std::fs::File> {
#[cfg(unix)] #[cfg(unix)]
{ {
let parent = OpenOptions::new() let parent = OpenOptions::new()
@ -168,23 +170,16 @@ fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std:
.custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC)
.open(&path.allowed_parent)?; .open(&path.allowed_parent)?;
let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes()) let oflags = OFlag::O_CREAT
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?; | OFlag::O_APPEND
| OFlag::O_WRONLY
let fd = unsafe { | OFlag::O_NOFOLLOW
libc::openat( | OFlag::O_CLOEXEC;
parent.as_raw_fd(), let mode = Mode::from_bits_truncate(0o600);
file_name.as_ptr(), let path_component = Path::new(path.file_name.as_os_str());
libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC, let fd = openat(&parent, path_component, oflags, mode)
0o600, .map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
) let file = std::fs::File::from(fd);
};
if fd < 0 {
return Err(std::io::Error::last_os_error());
}
let file = unsafe { std::fs::File::from_raw_fd(fd) };
Ok(file) Ok(file)
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -200,16 +195,13 @@ fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std:
fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> {
#[cfg(unix)] #[cfg(unix)]
{ {
if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 { let cloned = file.try_clone()?;
return Err(std::io::Error::last_os_error()); let mut locked = Flock::lock(cloned, FlockArg::LockExclusive)
} .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?;
let write_result = writeln!(&mut *locked, "dc_idx={dc_idx}");
let write_result = writeln!(file, "dc_idx={dc_idx}"); let _ = locked
.unlock()
if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 { .map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?;
return Err(std::io::Error::last_os_error());
}
write_result write_result
} }
#[cfg(not(unix))] #[cfg(not(unix))]

View File

@ -2,29 +2,29 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::net::SocketAddr; use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use std::collections::HashSet; use std::collections::HashSet;
use std::collections::hash_map::RandomState; use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv6Addr}; use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::hash::{BuildHasher, Hash, Hasher};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace}; use tracing::{debug, trace, warn};
use zeroize::{Zeroize, Zeroizing}; use zeroize::{Zeroize, Zeroizing};
use crate::crypto::{sha256, AesCtr, SecureRandom}; use crate::config::ProxyConfig;
use rand::RngExt; use crate::crypto::{AesCtr, SecureRandom, sha256};
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
use crate::error::{ProxyError, HandshakeResult};
use crate::stats::ReplayChecker; use crate::stats::ReplayChecker;
use crate::config::ProxyConfig; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::tls_front::{TlsFrontCache, emulator}; use crate::tls_front::{TlsFrontCache, emulator};
use rand::RngExt;
const ACCESS_SECRET_BYTES: usize = 16; const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new(); static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
@ -67,7 +67,8 @@ struct AuthProbeSaturationState {
} }
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new(); static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new(); static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> =
OnceLock::new();
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new(); static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> { fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
@ -78,8 +79,8 @@ fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationSta
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
} }
fn auth_probe_saturation_state_lock( fn auth_probe_saturation_state_lock()
) -> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> { -> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> {
auth_probe_saturation_state() auth_probe_saturation_state()
.lock() .lock()
.unwrap_or_else(|poisoned| poisoned.into_inner()) .unwrap_or_else(|poisoned| poisoned.into_inner())
@ -252,9 +253,7 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => || (fail_streak == current_fail && last_seen >= current_seen) => {}
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
} }
@ -284,9 +283,7 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => || (fail_streak == current_fail && last_seen >= current_seen) => {}
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
if auth_probe_state_expired(entry.value(), now) { if auth_probe_state_expired(entry.value(), now) {
@ -306,9 +303,7 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => || (fail_streak == current_fail && last_seen >= current_seen) => {}
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
if auth_probe_state_expired(entry.value(), now) { if auth_probe_state_expired(entry.value(), now) {
@ -545,7 +540,6 @@ pub struct HandshakeSuccess {
/// Client address /// Client address
pub peer: SocketAddr, pub peer: SocketAddr,
/// Whether TLS was used /// Whether TLS was used
pub is_tls: bool, pub is_tls: bool,
} }
@ -632,7 +626,7 @@ where
let cached = if config.censorship.tls_emulation { let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() { if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = client_sni.as_ref() { let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(&sni).await { if cache.contains_domain(sni).await {
sni.clone() sni.clone()
} else { } else {
config.censorship.tls_domain.clone() config.censorship.tls_domain.clone()
@ -769,7 +763,6 @@ where
let decoded_users = decode_user_secrets(config, preferred_user); let decoded_users = decode_user_secrets(config, preferred_user);
for (user, secret) in decoded_users { for (user, secret) in decoded_users {
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@ -820,7 +813,7 @@ where
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
// Apply replay tracking only after successful authentication. // Apply replay tracking only after successful authentication.
// //
// This ordering prevents an attacker from producing invalid handshakes that // This ordering prevents an attacker from producing invalid handshakes that
// still collide with a valid handshake's replay slot and thus evict a valid // still collide with a valid handshake's replay slot and thus evict a valid
@ -885,13 +878,19 @@ pub fn generate_tg_nonce(
continue; continue;
}; };
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
continue;
}
let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]];
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
continue;
}
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route // CRITICAL: write dc_idx so upstream DC knows where to route
@ -955,7 +954,6 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
} }
/// Encrypt nonce for sending to Telegram (legacy function for compatibility) /// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted encrypted

View File

@ -1,19 +1,19 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::str;
use std::net::SocketAddr;
use std::time::Duration;
use rand::{Rng, RngExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
use std::str;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::{Instant, timeout};
use tracing::debug;
#[cfg(not(test))] #[cfg(not(test))]
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
@ -98,8 +98,8 @@ async fn maybe_write_shape_padding<W>(
cap: usize, cap: usize,
above_cap_blur: bool, above_cap_blur: bool,
above_cap_blur_max_bytes: usize, above_cap_blur_max_bytes: usize,
) aggressive_mode: bool,
where ) where
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
{ {
if !enabled { if !enabled {
@ -108,7 +108,11 @@ where
let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 { let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 {
let mut rng = rand::rng(); let mut rng = rand::rng();
let extra = rng.random_range(0..=above_cap_blur_max_bytes); let extra = if aggressive_mode {
rng.random_range(1..=above_cap_blur_max_bytes)
} else {
rng.random_range(0..=above_cap_blur_max_bytes)
};
total_sent.saturating_add(extra) total_sent.saturating_add(extra)
} else { } else {
next_mask_shape_bucket(total_sent, floor, cap) next_mask_shape_bucket(total_sent, floor, cap)
@ -167,7 +171,10 @@ async fn consume_client_data_with_timeout<R>(reader: R)
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
{ {
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() { if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
.await
.is_err()
{
debug!("Timed out while consuming client data on masking fallback path"); debug!("Timed out while consuming client data on masking fallback path");
} }
} }
@ -213,9 +220,12 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
fn detect_client_type(data: &[u8]) -> &'static str { fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request // Check for HTTP request
if data.len() > 4 if data.len() > 4
&& (data.starts_with(b"GET ") || data.starts_with(b"POST") || && (data.starts_with(b"GET ")
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") || || data.starts_with(b"POST")
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS")) || data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{ {
return "HTTP"; return "HTTP";
} }
@ -252,16 +262,12 @@ fn build_mask_proxy_header(
), ),
_ => { _ => {
let header = match (peer, local_addr) { let header = match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) => { (SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new()
ProxyProtocolV1Builder::new()
.tcp4(src.into(), dst.into()) .tcp4(src.into(), dst.into())
.build() .build(),
} (SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new()
(SocketAddr::V6(src), SocketAddr::V6(dst)) => {
ProxyProtocolV1Builder::new()
.tcp6(src.into(), dst.into()) .tcp6(src.into(), dst.into())
.build() .build(),
}
_ => ProxyProtocolV1Builder::new().build(), _ => ProxyProtocolV1Builder::new().build(),
}; };
Some(header) Some(header)
@ -278,8 +284,7 @@ pub async fn handle_bad_client<R, W>(
local_addr: SocketAddr, local_addr: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
beobachten: &BeobachtenStore, beobachten: &BeobachtenStore,
) ) where
where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
{ {
@ -311,14 +316,17 @@ where
match connect_result { match connect_result {
Ok(Ok(stream)) => { Ok(Ok(stream)) => {
let (mask_read, mut mask_write) = stream.into_split(); let (mask_read, mut mask_write) = stream.into_split();
let proxy_header = let proxy_header = build_mask_proxy_header(
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); config.censorship.mask_proxy_protocol,
if let Some(header) = proxy_header { peer,
if !write_proxy_header_with_timeout(&mut mask_write, &header).await { local_addr,
);
if let Some(header) = proxy_header
&& !write_proxy_header_with_timeout(&mut mask_write, &header).await
{
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
return; return;
} }
}
if timeout( if timeout(
MASK_RELAY_TIMEOUT, MASK_RELAY_TIMEOUT,
relay_to_mask( relay_to_mask(
@ -332,6 +340,7 @@ where
config.censorship.mask_shape_bucket_cap_bytes, config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
), ),
) )
.await .await
@ -356,7 +365,10 @@ where
return; return;
} }
let mask_host = config.censorship.mask_host.as_deref() let mask_host = config
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain); .unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port; let mask_port = config.censorship.mask_port;
@ -381,12 +393,12 @@ where
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
let (mask_read, mut mask_write) = stream.into_split(); let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header { if let Some(header) = proxy_header
if !write_proxy_header_with_timeout(&mut mask_write, &header).await { && !write_proxy_header_with_timeout(&mut mask_write, &header).await
{
wait_mask_outcome_budget(outcome_started, config).await; wait_mask_outcome_budget(outcome_started, config).await;
return; return;
} }
}
if timeout( if timeout(
MASK_RELAY_TIMEOUT, MASK_RELAY_TIMEOUT,
relay_to_mask( relay_to_mask(
@ -400,6 +412,7 @@ where
config.censorship.mask_shape_bucket_cap_bytes, config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
), ),
) )
.await .await
@ -435,8 +448,8 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_bucket_cap_bytes: usize, shape_bucket_cap_bytes: usize,
shape_above_cap_blur: bool, shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize, shape_above_cap_blur_max_bytes: usize,
) shape_hardening_aggressive_mode: bool,
where ) where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static, MR: AsyncRead + Unpin + Send + 'static,
@ -450,14 +463,17 @@ where
return; return;
} }
let _ = tokio::join!( let (upstream_copy, downstream_copy) = tokio::join!(
async { async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
let copied = copy_with_idle_timeout(&mut reader, &mut mask_write).await; async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
let total_sent = initial_data.len().saturating_add(copied.total); );
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
let should_shape = shape_hardening_enabled let should_shape = shape_hardening_enabled
&& copied.ended_by_eof && !initial_data.is_empty()
&& !initial_data.is_empty(); && (upstream_copy.ended_by_eof
|| (shape_hardening_aggressive_mode && downstream_copy.total == 0));
maybe_write_shape_padding( maybe_write_shape_padding(
&mut mask_write, &mut mask_write,
@ -467,15 +483,12 @@ where
shape_bucket_cap_bytes, shape_bucket_cap_bytes,
shape_above_cap_blur, shape_above_cap_blur,
shape_above_cap_blur_max_bytes, shape_above_cap_blur_max_bytes,
shape_hardening_aggressive_mode,
) )
.await; .await;
let _ = mask_write.shutdown().await; let _ = mask_write.shutdown().await;
},
async {
let _ = copy_with_idle_timeout(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await; let _ = writer.shutdown().await;
}
);
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
@ -524,6 +537,14 @@ mod masking_shape_guard_adversarial_tests;
#[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"] #[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"]
mod masking_shape_classifier_resistance_adversarial_tests; mod masking_shape_classifier_resistance_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_bypass_blackhat_tests.rs"]
mod masking_shape_bypass_blackhat_tests;
#[cfg(test)]
#[path = "tests/masking_aggressive_mode_security_tests.rs"]
mod masking_aggressive_mode_security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests; mod masking_timing_sidechannel_redteam_expected_fail_tests;

View File

@ -1,7 +1,6 @@
use std::collections::hash_map::RandomState; use std::collections::hash_map::RandomState;
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use std::hash::BuildHasher; use std::hash::{BuildHasher, Hash};
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
@ -9,17 +8,17 @@ use std::time::{Duration, Instant};
use dashmap::DashMap; use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex}; use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::{*, secure_padding_len}; use crate::protocol::constants::{secure_padding_len, *};
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{ use crate::proxy::route_mode::{
RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay, cutover_stagger_delay,
}; };
use crate::stats::Stats; use crate::stats::Stats;
@ -50,11 +49,16 @@ const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const QUOTA_USER_LOCKS_MAX: usize = 64; const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))] #[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096; const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new(); static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new(); static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new(); static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@ -286,9 +290,7 @@ impl MeD2cFlushPolicy {
fn hash_value<T: Hash>(value: &T) -> u64 { fn hash_value<T: Hash>(value: &T) -> u64 {
let state = DESYNC_HASHER.get_or_init(RandomState::new); let state = DESYNC_HASHER.get_or_init(RandomState::new);
let mut hasher = state.build_hasher(); state.hash_one(value)
value.hash(&mut hasher);
hasher.finish()
} }
fn hash_ip(ip: IpAddr) -> u64 { fn hash_ip(ip: IpAddr) -> u64 {
@ -416,6 +418,13 @@ fn desync_dedup_test_lock() -> &'static Mutex<()> {
TEST_LOCK.get_or_init(|| Mutex::new(())) TEST_LOCK.get_or_init(|| Mutex::new(()))
} }
fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) {
match u32::try_from(len) {
Ok(value) => (value.to_le_bytes(), false),
Err(_) => (u32::MAX.to_le_bytes(), true),
}
}
fn report_desync_frame_too_large( fn report_desync_frame_too_large(
state: &RelayForensicsState, state: &RelayForensicsState,
proto_tag: ProtoTag, proto_tag: ProtoTag,
@ -425,7 +434,8 @@ fn report_desync_frame_too_large(
raw_len_bytes: Option<[u8; 4]>, raw_len_bytes: Option<[u8; 4]>,
stats: &Stats, stats: &Stats,
) -> ProxyError { ) -> ProxyError {
let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes()); let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len);
let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf);
let looks_like_tls = raw_len_bytes let looks_like_tls = raw_len_bytes
.map(|b| b[0] == 0x16 && b[1] == 0x03) .map(|b| b[0] == 0x16 && b[1] == 0x03)
.unwrap_or(false); .unwrap_or(false);
@ -461,6 +471,7 @@ fn report_desync_frame_too_large(
bytes_me2c, bytes_me2c,
raw_len = len, raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len), raw_len_hex = format_args!("0x{:08x}", len),
raw_len_bytes_truncated = len_buf_truncated,
raw_bytes = format_args!( raw_bytes = format_args!(
"{:02x} {:02x} {:02x} {:02x}", "{:02x} {:02x} {:02x} {:02x}",
len_buf[0], len_buf[1], len_buf[2], len_buf[3] len_buf[0], len_buf[1], len_buf[2], len_buf[3]
@ -503,8 +514,7 @@ fn report_desync_frame_too_large(
ProxyError::Proxy(format!( ProxyError::Proxy(format!(
"Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}",
state.conn_id, state.conn_id, state.trace_id
state.trace_id
)) ))
} }
@ -528,6 +538,30 @@ fn quota_would_be_exceeded_for_user(
}) })
} }
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(AsyncMutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> { fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) { if let Some(existing) = locks.get(user) {
@ -539,7 +573,7 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
} }
if locks.len() >= QUOTA_USER_LOCKS_MAX { if locks.len() >= QUOTA_USER_LOCKS_MAX {
return Arc::new(AsyncMutex::new(())); return quota_overflow_user_lock(user);
} }
let created = Arc::new(AsyncMutex::new(())); let created = Arc::new(AsyncMutex::new(()));
@ -629,11 +663,9 @@ where
stats.increment_user_connects(&user); stats.increment_user_connects(&user);
let _me_connection_lease = stats.acquire_me_connection_lease(); let _me_connection_lease = stats.acquire_me_connection_lease();
if let Some(cutover) = affected_cutover_state( if let Some(cutover) =
&route_rx, affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation)
RelayRouteMode::Middle, {
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation); let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!( warn!(
conn_id, conn_id,
@ -689,13 +721,13 @@ where
.max(C2ME_CHANNEL_CAPACITY_FALLBACK); .max(C2ME_CHANNEL_CAPACITY_FALLBACK);
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity); let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone(); let me_pool_c2me = me_pool.clone();
let effective_tag = effective_tag;
let c2me_sender = tokio::spawn(async move { let c2me_sender = tokio::spawn(async move {
let mut sent_since_yield = 0usize; let mut sent_since_yield = 0usize;
while let Some(cmd) = c2me_rx.recv().await { while let Some(cmd) = c2me_rx.recv().await {
match cmd { match cmd {
C2MeCommand::Data { payload, flags } => { C2MeCommand::Data { payload, flags } => {
me_pool_c2me.send_proxy_req( me_pool_c2me
.send_proxy_req(
conn_id, conn_id,
success.dc_idx, success.dc_idx,
peer, peer,
@ -703,7 +735,8 @@ where
payload.as_ref(), payload.as_ref(),
flags, flags,
effective_tag.as_deref(), effective_tag.as_deref(),
).await?; )
.await?;
sent_since_yield = sent_since_yield.saturating_add(1); sent_since_yield = sent_since_yield.saturating_add(1);
if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) {
sent_since_yield = 0; sent_since_yield = 0;
@ -916,7 +949,11 @@ where
let mut seen_pressure_seq = relay_pressure_event_seq(); let mut seen_pressure_seq = relay_pressure_event_seq();
loop { loop {
if relay_idle_policy.enabled if relay_idle_policy.enabled
&& maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen_pressure_seq, stats.as_ref()) && maybe_evict_idle_candidate_on_pressure(
conn_id,
&mut seen_pressure_seq,
stats.as_ref(),
)
{ {
info!( info!(
conn_id, conn_id,
@ -931,11 +968,9 @@ where
break; break;
} }
if let Some(cutover) = affected_cutover_state( if let Some(cutover) =
&route_rx, affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation)
RelayRouteMode::Middle, {
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation); let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!( warn!(
conn_id, conn_id,
@ -1102,7 +1137,8 @@ where
return deadline; return deadline;
} }
let downstream_at = session_started_at + Duration::from_millis(last_downstream_activity_ms); let downstream_at =
session_started_at + Duration::from_millis(last_downstream_activity_ms);
if downstream_at > idle_state.last_client_frame_at { if downstream_at > idle_state.last_client_frame_at {
let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity;
if grace_deadline > deadline { if grace_deadline > deadline {
@ -1117,12 +1153,8 @@ where
let timeout_window = if idle_policy.enabled { let timeout_window = if idle_policy.enabled {
let now = Instant::now(); let now = Instant::now();
let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed);
let hard_deadline = hard_deadline( let hard_deadline =
idle_policy, hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms);
idle_state,
session_started_at,
downstream_ms,
);
if now >= hard_deadline { if now >= hard_deadline {
clear_relay_idle_candidate(forensics.conn_id); clear_relay_idle_candidate(forensics.conn_id);
stats.increment_relay_idle_hard_close_total(); stats.increment_relay_idle_hard_close_total();
@ -1130,7 +1162,9 @@ where
.saturating_duration_since(idle_state.last_client_frame_at) .saturating_duration_since(idle_state.last_client_frame_at)
.as_secs(); .as_secs();
let downstream_idle_secs = now let downstream_idle_secs = now
.saturating_duration_since(session_started_at + Duration::from_millis(downstream_ms)) .saturating_duration_since(
session_started_at + Duration::from_millis(downstream_ms),
)
.as_secs(); .as_secs();
warn!( warn!(
trace_id = format_args!("0x{:016x}", forensics.trace_id), trace_id = format_args!("0x{:016x}", forensics.trace_id),
@ -1204,7 +1238,9 @@ where
Err(_) if !idle_policy.enabled => { Err(_) if !idle_policy.enabled => {
return Err(ProxyError::Io(std::io::Error::new( return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut, std::io::ErrorKind::TimedOut,
format!("middle-relay client frame read timeout while reading {read_label}"), format!(
"middle-relay client frame read timeout while reading {read_label}"
),
))); )));
} }
Err(_) => {} Err(_) => {}
@ -1470,14 +1506,7 @@ where
user: user.to_string(), user: user.to_string(),
}); });
} }
write_client_payload( write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?; .await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
@ -1489,14 +1518,7 @@ where
}); });
} }
} else { } else {
write_client_payload( write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?; .await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
@ -1534,6 +1556,31 @@ where
} }
} }
fn compute_intermediate_secure_wire_len(
data_len: usize,
padding_len: usize,
quickack: bool,
) -> Result<(u32, usize)> {
let wire_len = data_len
.checked_add(padding_len)
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
if wire_len > 0x7fff_ffffusize {
return Err(ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
)));
}
let total = 4usize
.checked_add(wire_len)
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
let mut len_val = u32::try_from(wire_len)
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
if quickack {
len_val |= 0x8000_0000;
}
Ok((len_val, total))
}
async fn write_client_payload<W>( async fn write_client_payload<W>(
client_writer: &mut CryptoWriter<W>, client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
@ -1603,11 +1650,8 @@ where
} else { } else {
0 0
}; };
let mut len_val = (data.len() + padding_len) as u32; let (len_val, total) =
if quickack { compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?;
len_val |= 0x8000_0000;
}
let total = 4 + data.len() + padding_len;
frame_buf.clear(); frame_buf.clear();
frame_buf.reserve(total); frame_buf.reserve(total);
frame_buf.extend_from_slice(&len_val.to_le_bytes()); frame_buf.extend_from_slice(&len_val.to_le_bytes());
@ -1657,3 +1701,23 @@ mod idle_policy_security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] #[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"]
mod desync_all_full_dedup_security_tests; mod desync_all_full_dedup_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;

View File

@ -1,13 +1,71 @@
//! Proxy Defs //! Proxy Defs
// Apply strict linting to proxy production code while keeping test builds noise-tolerant.
#![cfg_attr(test, allow(warnings))]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
pub mod adaptive_buffers; pub mod adaptive_buffers;
pub mod client; pub mod client;
pub mod direct_relay; pub mod direct_relay;
pub mod handshake; pub mod handshake;
pub mod masking; pub mod masking;
pub mod middle_relay; pub mod middle_relay;
pub mod route_mode;
pub mod relay; pub mod relay;
pub mod route_mode;
pub mod session_eviction; pub mod session_eviction;
pub use client::ClientHandler; pub use client::ClientHandler;

View File

@ -51,21 +51,19 @@
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to` //! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking //! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool; use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
// ============= Constants ============= // ============= Constants =============
@ -251,7 +249,8 @@ impl<S> StatsIo<S> {
impl<S> Drop for StatsIo<S> { impl<S> Drop for StatsIo<S> {
fn drop(&mut self) { fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed); self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active.store(false, Ordering::Relaxed); self.quota_write_retry_active
.store(false, Ordering::Relaxed);
} }
} }
@ -428,7 +427,9 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
} }
// C→S: client sent data // C→S: client sent data
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed); this.counters
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch); this.counters.touch(Instant::now(), this.epoch);
@ -467,7 +468,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
match lock.try_lock() { match lock.try_lock() {
Ok(guard) => { Ok(guard) => {
this.quota_write_wake_scheduled = false; this.quota_write_wake_scheduled = false;
this.quota_write_retry_active.store(false, Ordering::Relaxed); this.quota_write_retry_active
.store(false, Ordering::Relaxed);
Some(guard) Some(guard)
} }
Err(_) => { Err(_) => {
@ -509,7 +511,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n > 0 { if n > 0 {
// S→C: data written to client // S→C: data written to client
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed); this.counters
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch); this.counters.touch(Instant::now(), this.epoch);

View File

@ -119,9 +119,7 @@ pub(crate) fn affected_cutover_state(
} }
pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration { pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration {
let mut value = session_id let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15;
^ generation.rotate_left(17)
^ 0x9e37_79b9_7f4a_7c15;
value ^= value >> 30; value ^= value >> 30;
value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9); value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9);
value ^= value >> 27; value ^= value >> 27;

View File

@ -1,3 +1,5 @@
#![allow(dead_code)]
/// Session eviction is intentionally disabled in runtime. /// Session eviction is intentionally disabled in runtime.
/// ///
/// The initial `user+dc` single-lease model caused valid parallel client /// The initial `user+dc` single-lease model caused valid parallel client

View File

@ -1,11 +1,11 @@
use super::*; use super::*;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::stats::Stats;
use crate::ip_tracker::UserIpTracker;
use crate::error::ProxyError; use crate::error::ProxyError;
use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) // Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6)
@ -20,7 +20,10 @@ async fn client_stress_10k_connections_limit_strict() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), limit); config
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let iterations = 1000; let iterations = 1000;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
@ -38,12 +41,10 @@ async fn client_stress_10k_connections_limit_strict() {
); );
match RunningClientHandler::acquire_user_connection_reservation_static( match RunningClientHandler::acquire_user_connection_reservation_static(
&user_str, &user_str, &config, stats, peer, ip_tracker,
&config, )
stats, .await
peer, {
ip_tracker,
).await {
Ok(res) => Ok(res), Ok(res) => Ok(res),
Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()), Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()),
Err(e) => panic!("Unexpected error: {:?}", e), Err(e) => panic!("Unexpected error: {:?}", e),
@ -67,15 +68,27 @@ async fn client_stress_10k_connections_limit_strict() {
} }
assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); assert_eq!(successes, limit, "Should allow exactly 'limit' connections");
assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded"); assert_eq!(
failures,
iterations - limit,
"Should fail the rest with LimitExceeded"
);
assert_eq!(stats.get_user_curr_connects(user), limit as u64); assert_eq!(stats.get_user_curr_connects(user), limit as u64);
drop(reservations); drop(reservations);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops"); assert_eq!(
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0"); stats.get_user_curr_connects(user),
0,
"Stats must converge to 0 after all drops"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"IP tracker must converge to 0"
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -106,7 +119,11 @@ async fn client_ip_tracker_race_condition_stress() {
futures::future::join_all(tasks).await; futures::future::join_all(tasks).await;
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst"); assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"IP count must be zero after balanced add/remove burst"
);
} }
#[tokio::test] #[tokio::test]
@ -119,7 +136,10 @@ async fn client_limit_burst_peak_never_exceeds_cap() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), limit); config
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let peak = Arc::new(AtomicU64::new(0)); let peak = Arc::new(AtomicU64::new(0));
let mut tasks = Vec::with_capacity(attempts); let mut tasks = Vec::with_capacity(attempts);
@ -207,10 +227,10 @@ async fn client_expiration_rejection_never_mutates_live_counters() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_expirations.insert(
.access user.to_string(),
.user_expirations chrono::Utc::now() - chrono::Duration::seconds(1),
.insert(user.to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); );
let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap();
let res = RunningClientHandler::acquire_user_connection_reservation_static( let res = RunningClientHandler::acquire_user_connection_reservation_static(
@ -235,7 +255,10 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() {
ip_tracker.set_user_limit(user, 1).await; ip_tracker.set_user_limit(user, 1).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 16); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 16);
let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap();
let first = RunningClientHandler::acquire_user_connection_reservation_static( let first = RunningClientHandler::acquire_user_connection_reservation_static(
@ -258,7 +281,10 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() {
) )
.await; .await;
assert!(matches!(second, Err(ProxyError::ConnectionLimitExceeded { .. }))); assert!(matches!(
second,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
assert_eq!(stats.get_user_curr_connects(user), 1); assert_eq!(stats.get_user_curr_connects(user), 1);
drop(first); drop(first);
@ -276,7 +302,10 @@ async fn client_parallel_limit_checks_success_path_leaves_no_residue() {
ip_tracker.set_user_limit(user, 128).await; ip_tracker.set_user_limit(user, 128).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 128); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 128);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..128u16 { for i in 0..128u16 {
@ -310,7 +339,10 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() {
ip_tracker.set_user_limit(user, 0).await; ip_tracker.set_user_limit(user, 0).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 512); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 512);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..64u16 { for i in 0..64u16 {
@ -319,7 +351,10 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() {
let config = config.clone(); let config = config.clone();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), 33000 + i); let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)),
33000 + i,
);
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker)
.await .await
})); }));
@ -360,11 +395,7 @@ async fn client_churn_mixed_success_failure_converges_to_zero_state() {
34000 + (i % 32), 34000 + (i % 32),
); );
let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await; .await;
@ -401,11 +432,7 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one()
let config = config.clone(); let config = config.clone();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
})); }));
@ -424,7 +451,10 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one()
} }
} }
assert_eq!(granted, 1, "only one reservation may be granted for same IP with limit=1"); assert_eq!(
granted, 1,
"only one reservation may be granted for same IP with limit=1"
);
drop(reservations); drop(reservations);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0); assert_eq!(stats.get_user_curr_connects(user), 0);
@ -439,7 +469,10 @@ async fn client_repeat_acquire_release_cycles_never_accumulate_state() {
ip_tracker.set_user_limit(user, 32).await; ip_tracker.set_user_limit(user, 32).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 32); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 32);
for i in 0..500u16 { for i in 0..500u16 {
let peer = SocketAddr::new( let peer = SocketAddr::new(
@ -484,11 +517,7 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() {
37000 + i, 37000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
})); }));
@ -497,7 +526,11 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() {
let mut u1_success = 0usize; let mut u1_success = 0usize;
let mut u2_success = 0usize; let mut u2_success = 0usize;
let mut reservations = Vec::new(); let mut reservations = Vec::new();
for (idx, result) in futures::future::join_all(tasks).await.into_iter().enumerate() { for (idx, result) in futures::future::join_all(tasks)
.await
.into_iter()
.enumerate()
{
let user = if idx % 2 == 0 { "u1" } else { "u2" }; let user = if idx % 2 == 0 { "u1" } else { "u2" };
match result.unwrap() { match result.unwrap() {
Ok(reservation) => { Ok(reservation) => {
@ -556,7 +589,10 @@ async fn client_limit_recovery_after_full_rejection_wave() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. }))); assert!(matches!(
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
} }
drop(reservation); drop(reservation);
@ -572,7 +608,10 @@ async fn client_limit_recovery_after_full_rejection_wave() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!(recovered.is_ok(), "capacity must recover after prior holder drops"); assert!(
recovered.is_ok(),
"capacity must recover after prior holder drops"
);
} }
#[tokio::test] #[tokio::test]
@ -619,7 +658,10 @@ async fn client_dual_limit_cross_product_never_leaks_on_reject() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. }))); assert!(matches!(
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
} }
assert_eq!(stats.get_user_curr_connects(user), 2); assert_eq!(stats.get_user_curr_connects(user), 2);
@ -637,7 +679,10 @@ async fn client_check_user_limits_concurrent_churn_no_counter_drift() {
ip_tracker.set_user_limit(user, 64).await; ip_tracker.set_user_limit(user, 64).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 64); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 64);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..512u16 { for i in 0..512u16 {

View File

@ -2,17 +2,14 @@ use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{ use crate::protocol::constants::{
HANDSHAKE_LEN, HANDSHAKE_LEN, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE, TLS_RECORD_APPLICATION,
MAX_TLS_PLAINTEXT_SIZE,
MIN_TLS_CLIENT_HELLO_SIZE,
TLS_RECORD_APPLICATION,
TLS_VERSION, TLS_VERSION,
}; };
use crate::protocol::tls; use crate::protocol::tls;
use std::collections::HashSet; use std::collections::HashSet;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -79,7 +76,10 @@ fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -171,7 +171,10 @@ async fn run_tls_success_mtproto_fail_capture(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
@ -427,7 +430,10 @@ async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello()
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await; read_and_discard_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap(); client_side.write_all(&first_tail).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -697,13 +703,15 @@ async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_is
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..sessions { for i in 0..sessions {
let mut harness = build_mask_harness("abababababababababababababababab", backend_addr.port()); let mut harness =
build_mask_harness("abababababababababababababababab", backend_addr.port());
let mut cfg = (*harness.config).clone(); let mut cfg = (*harness.config).clone();
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
harness.config = Arc::new(cfg); harness.config = Arc::new(cfg);
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let secret = [0xABu8; 16]; let secret = [0xABu8; 16];
let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10)); let hello =
make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10));
let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]); let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]);
let (server_side, mut client_side) = duplex(131072); let (server_side, mut client_side) = duplex(131072);
@ -843,8 +851,8 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() {
tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024); tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024);
} }
let body_to_send = if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) let body_to_send =
{ if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) {
(seed as usize % 29).min(tls_len.saturating_sub(1)) (seed as usize % 29).min(tls_len.saturating_sub(1))
} else { } else {
0 0
@ -856,7 +864,9 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() {
probe[2] = 0x01; probe[2] = 0x01;
probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
for b in &mut probe[5..] { for b in &mut probe[5..] {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
*b = (seed >> 24) as u8; *b = (seed >> 24) as u8;
} }
@ -879,7 +889,8 @@ async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics()
probe[2] = 0x01; probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes()); probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill((0x90 + i as u8) ^ 0x5A); probe[5..].fill((0x90 + i as u8) ^ 0x5A);
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe).await; run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe)
.await;
} else { } else {
let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8]; let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8];
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await; run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await;

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -55,7 +55,10 @@ fn build_harness(config: ProxyConfig) -> PipelineHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -150,7 +153,10 @@ async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay()
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert!(result.is_ok(), "bad-client fallback must not be canceled by handshake timeout"); assert!(
result.is_ok(),
"bad-client fallback must not be canceled by handshake timeout"
);
assert_eq!( assert_eq!(
stats.get_handshake_timeouts(), stats.get_handshake_timeouts(),
0, 0,
@ -175,10 +181,10 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
config.censorship.mask_port = backend_addr.port(); config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0; config.censorship.mask_proxy_protocol = 0;
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
config config.access.users.insert(
.access "user".to_string(),
.users "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string(),
.insert("user".to_string(), "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string()); );
let harness = build_harness(config); let harness = build_harness(config);
@ -194,8 +200,7 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
let mut got = vec![0u8; expected_trailing.len()]; let mut got = vec![0u8; expected_trailing.len()];
stream.read_exact(&mut got).await.unwrap(); stream.read_exact(&mut got).await.unwrap();
assert_eq!( assert_eq!(
got, got, expected_trailing,
expected_trailing,
"mask backend must receive only post-handshake trailing TLS records" "mask backend must receive only post-handshake trailing TLS records"
); );
}); });
@ -223,11 +228,17 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -163,21 +163,36 @@ async fn diagnostic_timing_profiles_are_within_realistic_guardrails() {
); );
assert!(p50 >= 650, "p50 too low for delayed reject class={}", class); assert!(p50 >= 650, "p50 too low for delayed reject class={}", class);
assert!(p95 <= 1200, "p95 too high for delayed reject class={}", class); assert!(
assert!(max <= 1500, "max too high for delayed reject class={}", class); p95 <= 1200,
"p95 too high for delayed reject class={}",
class
);
assert!(
max <= 1500,
"max too high for delayed reject class={}",
class
);
} }
} }
#[tokio::test] #[tokio::test]
async fn diagnostic_forwarded_size_profiles_by_probe_class() { async fn diagnostic_forwarded_size_profiles_by_probe_class() {
let classes = [0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize]; let classes = [
0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize,
];
let mut observed = Vec::new(); let mut observed = Vec::new();
for class in classes { for class in classes {
let len = capture_forwarded_len(class).await; let len = capture_forwarded_len(class).await;
println!("diagnostic_shape class={} forwarded_len={}", class, len); println!("diagnostic_shape class={} forwarded_len={}", class, len);
observed.push(len as u128); observed.push(len as u128);
assert_eq!(len, 5 + class, "unexpected forwarded len for class={}", class); assert_eq!(
len,
5 + class,
"unexpected forwarded len for class={}",
class
);
} }
let p50 = percentile_ms(observed.clone(), 50, 100); let p50 = percentile_ms(observed.clone(), 50, 100);

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -70,7 +70,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> Harness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -158,11 +161,17 @@ async fn run_tls_success_mtproto_fail_capture(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_tls_record_body(&mut client_side, tls_response_head).await; read_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for record in trailing_records { for record in trailing_records {
client_side.write_all(&record).await.unwrap(); client_side.write_all(&record).await.unwrap();
} }
@ -330,7 +339,10 @@ async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() {
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
read_tls_record_body(&mut client_side, head).await; read_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap(); client_side.write_all(&first_tail).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -402,7 +414,10 @@ async fn connects_bad_increments_once_per_invalid_mtproto() {
let mut head = [0u8; 5]; let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await; read_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&tail).await.unwrap(); client_side.write_all(&tail).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -625,7 +640,8 @@ async fn concurrent_tls_mtproto_fail_sessions_are_isolated() {
for idx in 0..sessions { for idx in 0..sessions {
let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"; let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4";
let harness = build_harness(secret_hex, backend_addr.port()); let harness = build_harness(secret_hex, backend_addr.port());
let hello = make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8); let hello =
make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8);
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]); let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]);
let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16) let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16)
@ -685,17 +701,67 @@ macro_rules! tail_length_case {
*b = (i as u8).wrapping_mul(17).wrapping_add(5); *b = (i as u8).wrapping_mul(17).wrapping_add(5);
} }
let record = wrap_tls_application_data(&payload); let record = wrap_tls_application_data(&payload);
let got = run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]).await; let got =
run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()])
.await;
assert_eq!(got, record); assert_eq!(got, record);
} }
}; };
} }
tail_length_case!(tail_len_1_preserved, "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", [0xD1; 16], 30, 1); tail_length_case!(
tail_length_case!(tail_len_2_preserved, "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", [0xD2; 16], 31, 2); tail_len_1_preserved,
tail_length_case!(tail_len_3_preserved, "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", [0xD3; 16], 32, 3); "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1",
tail_length_case!(tail_len_7_preserved, "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", [0xD4; 16], 33, 7); [0xD1; 16],
tail_length_case!(tail_len_31_preserved, "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", [0xD5; 16], 34, 31); 30,
tail_length_case!(tail_len_127_preserved, "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", [0xD6; 16], 35, 127); 1
tail_length_case!(tail_len_511_preserved, "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", [0xD7; 16], 36, 511); );
tail_length_case!(tail_len_1023_preserved, "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", [0xD8; 16], 37, 1023); tail_length_case!(
tail_len_2_preserved,
"d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2",
[0xD2; 16],
31,
2
);
tail_length_case!(
tail_len_3_preserved,
"d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3",
[0xD3; 16],
32,
3
);
tail_length_case!(
tail_len_7_preserved,
"d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4",
[0xD4; 16],
33,
7
);
tail_length_case!(
tail_len_31_preserved,
"d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5",
[0xD5; 16],
34,
31
);
tail_length_case!(
tail_len_127_preserved,
"d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6",
[0xD6; 16],
35,
127
);
tail_length_case!(
tail_len_511_preserved,
"d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7",
[0xD7; 16],
36,
511
);
tail_length_case!(
tail_len_1023_preserved,
"d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8",
[0xD8; 16],
37,
1023
);

View File

@ -5,7 +5,7 @@ use rand::{Rng, SeedableRng};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
@ -92,7 +92,10 @@ async fn run_generic_probe_and_capture_prefix(payload: Vec<u8>, expected_prefix:
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
@ -264,7 +267,8 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let mut expected = std::collections::HashSet::new(); let mut expected = std::collections::HashSet::new();
for idx in 0..session_count { for idx in 0..session_count {
let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); let probe =
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
expected.insert(probe); expected.insert(probe);
} }
@ -274,9 +278,15 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let head = read_http_probe_header(&mut stream).await; let head = read_http_probe_header(&mut stream).await;
stream.write_all(REPLY_404).await.unwrap(); stream.write_all(REPLY_404).await.unwrap();
assert!(remaining.remove(&head), "backend received unexpected or duplicated probe prefix"); assert!(
remaining.remove(&head),
"backend received unexpected or duplicated probe prefix"
);
} }
assert!(remaining.is_empty(), "all session prefixes must be observed exactly once"); assert!(
remaining.is_empty(),
"all session prefixes must be observed exactly once"
);
}); });
let mut tasks = Vec::with_capacity(session_count); let mut tasks = Vec::with_capacity(session_count);
@ -291,7 +301,8 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new()); let beobachten = Arc::new(BeobachtenStore::new());
let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); let probe =
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx)
.parse() .parse()
.unwrap(); .unwrap();
@ -319,7 +330,10 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -67,7 +67,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -148,8 +151,14 @@ async fn run_tls_success_mtproto_fail_session(
let mut body = vec![0u8; body_len]; let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap(); client_side.read_exact(&mut body).await.unwrap();
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
client_side.write_all(&wrap_tls_application_data(&tail)).await.unwrap(); .write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side
.write_all(&wrap_tls_application_data(&tail))
.await
.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await .await
@ -175,7 +184,10 @@ async fn redteam_01_backend_receives_no_data_after_mtproto_fail() {
b"probe-a".to_vec(), b"probe-a".to_vec(),
) )
.await; .await;
assert!(forwarded.is_empty(), "backend unexpectedly received fallback bytes"); assert!(
forwarded.is_empty(),
"backend unexpectedly received fallback bytes"
);
} }
#[tokio::test] #[tokio::test]
@ -188,7 +200,10 @@ async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail()
b"probe-b".to_vec(), b"probe-b".to_vec(),
) )
.await; .await;
assert_ne!(forwarded[0], 0x17, "received TLS application record despite strict policy"); assert_ne!(
forwarded[0], 0x17,
"received TLS application record despite strict policy"
);
} }
#[tokio::test] #[tokio::test]
@ -200,9 +215,10 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() {
cfg.censorship.mask_host = Some("127.0.0.1".to_string()); cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1; cfg.censorship.mask_port = 1;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "acacacacacacacacacacacacacacacac".to_string()); "acacacacacacacacacacacacacacacac".to_string(),
);
let harness = RedTeamHarness { let harness = RedTeamHarness {
config: Arc::new(cfg), config: Arc::new(cfg),
@ -261,7 +277,10 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() {
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert!(started.elapsed() < Duration::from_millis(1), "fallback path took longer than 1ms"); assert!(
started.elapsed() < Duration::from_millis(1),
"fallback path took longer than 1ms"
);
} }
macro_rules! redteam_tail_must_not_forward_case { macro_rules! redteam_tail_must_not_forward_case {
@ -283,18 +302,90 @@ macro_rules! redteam_tail_must_not_forward_case {
}; };
} }
redteam_tail_must_not_forward_case!(redteam_04_tail_len_1_not_forwarded, "adadadadadadadadadadadadadadadad", [0xAD; 16], 4, 1); redteam_tail_must_not_forward_case!(
redteam_tail_must_not_forward_case!(redteam_05_tail_len_2_not_forwarded, "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", [0xAE; 16], 5, 2); redteam_04_tail_len_1_not_forwarded,
redteam_tail_must_not_forward_case!(redteam_06_tail_len_3_not_forwarded, "afafafafafafafafafafafafafafafaf", [0xAF; 16], 6, 3); "adadadadadadadadadadadadadadadad",
redteam_tail_must_not_forward_case!(redteam_07_tail_len_7_not_forwarded, "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", [0xB0; 16], 7, 7); [0xAD; 16],
redteam_tail_must_not_forward_case!(redteam_08_tail_len_15_not_forwarded, "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", [0xB1; 16], 8, 15); 4,
redteam_tail_must_not_forward_case!(redteam_09_tail_len_63_not_forwarded, "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", [0xB2; 16], 9, 63); 1
redteam_tail_must_not_forward_case!(redteam_10_tail_len_127_not_forwarded, "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", [0xB3; 16], 10, 127); );
redteam_tail_must_not_forward_case!(redteam_11_tail_len_255_not_forwarded, "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", [0xB4; 16], 11, 255); redteam_tail_must_not_forward_case!(
redteam_tail_must_not_forward_case!(redteam_12_tail_len_511_not_forwarded, "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", [0xB5; 16], 12, 511); redteam_05_tail_len_2_not_forwarded,
redteam_tail_must_not_forward_case!(redteam_13_tail_len_1023_not_forwarded, "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", [0xB6; 16], 13, 1023); "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae",
redteam_tail_must_not_forward_case!(redteam_14_tail_len_2047_not_forwarded, "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", [0xB7; 16], 14, 2047); [0xAE; 16],
redteam_tail_must_not_forward_case!(redteam_15_tail_len_4095_not_forwarded, "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", [0xB8; 16], 15, 4095); 5,
2
);
redteam_tail_must_not_forward_case!(
redteam_06_tail_len_3_not_forwarded,
"afafafafafafafafafafafafafafafaf",
[0xAF; 16],
6,
3
);
redteam_tail_must_not_forward_case!(
redteam_07_tail_len_7_not_forwarded,
"b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0",
[0xB0; 16],
7,
7
);
redteam_tail_must_not_forward_case!(
redteam_08_tail_len_15_not_forwarded,
"b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1",
[0xB1; 16],
8,
15
);
redteam_tail_must_not_forward_case!(
redteam_09_tail_len_63_not_forwarded,
"b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2",
[0xB2; 16],
9,
63
);
redteam_tail_must_not_forward_case!(
redteam_10_tail_len_127_not_forwarded,
"b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3",
[0xB3; 16],
10,
127
);
redteam_tail_must_not_forward_case!(
redteam_11_tail_len_255_not_forwarded,
"b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4",
[0xB4; 16],
11,
255
);
redteam_tail_must_not_forward_case!(
redteam_12_tail_len_511_not_forwarded,
"b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5",
[0xB5; 16],
12,
511
);
redteam_tail_must_not_forward_case!(
redteam_13_tail_len_1023_not_forwarded,
"b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6",
[0xB6; 16],
13,
1023
);
redteam_tail_must_not_forward_case!(
redteam_14_tail_len_2047_not_forwarded,
"b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7",
[0xB7; 16],
14,
2047
);
redteam_tail_must_not_forward_case!(
redteam_15_tail_len_4095_not_forwarded,
"b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8",
[0xB8; 16],
15,
4095
);
#[tokio::test] #[tokio::test]
#[ignore = "red-team expected-fail: impossible indistinguishability envelope"] #[ignore = "red-team expected-fail: impossible indistinguishability envelope"]
@ -349,14 +440,13 @@ async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency
let min = durations.iter().copied().min().unwrap(); let min = durations.iter().copied().min().unwrap();
let max = durations.iter().copied().max().unwrap(); let max = durations.iter().copied().max().unwrap();
assert!(max - min <= Duration::from_millis(1), "timing spread too wide for strict anti-probing envelope"); assert!(
max - min <= Duration::from_millis(1),
"timing spread too wide for strict anti-probing envelope"
);
} }
async fn measure_invalid_probe_duration_ms( async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sent: usize) -> u128 {
delay_ms: u64,
tls_len: u16,
body_sent: usize,
) -> u128 {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false; cfg.general.beobachten = false;
cfg.censorship.mask = true; cfg.censorship.mask = true;
@ -501,7 +591,8 @@ macro_rules! redteam_timing_envelope_case {
#[tokio::test] #[tokio::test]
#[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"] #[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"]
async fn $name() { async fn $name() {
let elapsed_ms = measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await; let elapsed_ms =
measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await;
assert!( assert!(
elapsed_ms <= $max_ms, elapsed_ms <= $max_ms,
"timing envelope violated: elapsed={}ms, max={}ms", "timing envelope violated: elapsed={}ms, max={}ms",
@ -519,11 +610,9 @@ macro_rules! redteam_constant_shape_case {
async fn $name() { async fn $name() {
let got = capture_forwarded_probe_len($tls_len, $body_sent).await; let got = capture_forwarded_probe_len($tls_len, $body_sent).await;
assert_eq!( assert_eq!(
got, got, $expected_len,
$expected_len,
"fingerprint shape mismatch: got={} expected={} (strict constant-shape model)", "fingerprint shape mismatch: got={} expected={} (strict constant-shape model)",
got, got, $expected_len
$expected_len
); );
} }
}; };

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -172,7 +172,10 @@ async fn redteam_fuzz_01_hardened_output_length_correlation_should_be_below_0_2(
let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect(); let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect();
let corr_hard = pearson_corr(&x, &y_hard).abs(); let corr_hard = pearson_corr(&x, &y_hard).abs();
println!("redteam_fuzz corr_hardened={corr_hard:.4} samples={}", sizes.len()); println!(
"redteam_fuzz corr_hardened={corr_hard:.4} samples={}",
sizes.len()
);
assert!( assert!(
corr_hard < 0.2, corr_hard < 0.2,
@ -234,9 +237,7 @@ async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() {
let corr_plain = pearson_corr(&x, &y_plain).abs(); let corr_plain = pearson_corr(&x, &y_plain).abs();
let corr_hard = pearson_corr(&x, &y_hard).abs(); let corr_hard = pearson_corr(&x, &y_hard).abs();
println!( println!("redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}");
"redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}"
);
assert!( assert!(
corr_hard <= corr_plain * 0.1, corr_hard <= corr_plain * 0.1,

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -164,10 +164,7 @@ async fn redteam_shape_02_padding_tail_must_be_non_deterministic() {
let cap = 4096usize; let cap = 4096usize;
let got = run_probe_capture(17, 600, true, floor, cap).await; let got = run_probe_capture(17, 600, true, floor, cap).await;
assert!( assert!(got.len() > 22, "test requires padding tail to exist");
got.len() > 22,
"test requires padding tail to exist"
);
let tail = &got[22..]; let tail = &got[22..];
assert!( assert!(
@ -194,7 +191,9 @@ async fn redteam_shape_03_exact_floor_input_should_not_be_fixed_point() {
async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() {
let floor = 512usize; let floor = 512usize;
let cap = 4096usize; let cap = 4096usize;
let classes = [17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize]; let classes = [
17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize,
];
let mut observed = Vec::new(); let mut observed = Vec::new();
for body in classes { for body in classes {
@ -203,7 +202,10 @@ async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() {
let first = observed[0]; let first = observed[0];
for v in observed { for v in observed {
assert_eq!(v, first, "strict model expects one collapsed class across all sub-cap probes"); assert_eq!(
v, first,
"strict model expects one collapsed class across all sub-cap probes"
);
} }
} }

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -70,7 +70,10 @@ fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -150,12 +153,8 @@ async fn run_parallel_tail_fallback_case(
for idx in 0..sessions { for idx in 0..sessions {
let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0"); let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0");
let hello = make_valid_tls_client_hello( let hello =
&[0xE0; 16], make_valid_tls_client_hello(&[0xE0; 16], ts_base + idx as u32, 600, 0x40 + (idx as u8));
ts_base + idx as u32,
600,
0x40 + (idx as u8),
);
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3];
@ -194,7 +193,10 @@ async fn run_parallel_tail_fallback_case(
client_side.write_all(&hello).await.unwrap(); client_side.write_all(&hello).await.unwrap();
let mut server_hello_head = [0u8; 5]; let mut server_hello_head = [0u8; 5];
client_side.read_exact(&mut server_hello_head).await.unwrap(); client_side
.read_exact(&mut server_hello_head)
.await
.unwrap();
assert_eq!(server_hello_head[0], 0x16); assert_eq!(server_hello_head[0], 0x16);
read_tls_record_body(&mut client_side, server_hello_head).await; read_tls_record_body(&mut client_side, server_hello_head).await;

View File

@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter}; use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
#[test] #[test]
@ -57,16 +57,24 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1);
assert_eq!(stats.get_user_curr_connects(&user), 1); assert_eq!(stats.get_user_curr_connects(&user), 1);
let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); let reservation =
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
// Drop the reservation synchronously without any tokio::spawn/await yielding! // Drop the reservation synchronously without any tokio::spawn/await yielding!
drop(reservation); drop(reservation);
// The IP is now inside the cleanup_queue, check that the queue has length 1 // The IP is now inside the cleanup_queue, check that the queue has length 1
let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len(); let queue_len = ip_tracker.cleanup_queue_len_for_tests();
assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue"); assert_eq!(
queue_len, 1,
"Reservation drop must push directly to synchronized IP queue"
);
assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately"); assert_eq!(
stats.get_user_curr_connects(&user),
0,
"Stats must decrement immediately"
);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0);
@ -286,7 +294,10 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
.await .await
.expect("relay must terminate after cutover") .expect("relay must terminate after cutover")
.expect("relay task must not panic"); .expect("relay task must not panic");
assert!(relay_result.is_err(), "cutover must terminate direct relay session"); assert!(
relay_result.is_err(),
"cutover must terminate direct relay session"
);
assert_eq!( assert_eq!(
stats.get_user_curr_connects(user), stats.get_user_curr_connects(user),
@ -447,7 +458,12 @@ async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() {
let mut reservations = Vec::new(); let mut reservations = Vec::new();
for idx in 0..512u16 { for idx in 0..512u16 {
let peer = std::net::SocketAddr::new( let peer = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)), std::net::IpAddr::V4(std::net::Ipv4Addr::new(
198,
51,
(idx >> 8) as u8,
(idx & 0xff) as u8,
)),
30_000 + idx, 30_000 + idx,
); );
let reservation = RunningClientHandler::acquire_user_connection_reservation_static( let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
@ -510,10 +526,15 @@ async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() {
false, false,
stats.clone(), stats.clone(),
)); ));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60))); let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(
128,
std::time::Duration::from_secs(60),
));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(
crate::proxy::route_mode::RelayRouteMode::Direct,
));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
@ -581,10 +602,16 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load(
false, false,
stats.clone(), stats.clone(),
)); ));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60))); let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(
64,
std::time::Duration::from_secs(60),
));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); let route_runtime =
std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(
crate::proxy::route_mode::RelayRouteMode::Direct,
));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
@ -669,8 +696,16 @@ async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
"second reservation must be rejected at the configured tcp-conns limit" "second reservation must be rejected at the configured tcp-conns limit"
); );
assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment"); assert_eq!(
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state"); stats.get_user_curr_connects(user),
1,
"failed acquisition must not leak a counter increment"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
1,
"failed acquisition must not mutate IP tracker state"
);
first.release().await; first.release().await;
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
@ -1119,7 +1154,10 @@ async fn partial_tls_header_stall_triggers_handshake_timeout() {
} }
fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> { fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![0x42u8; total_len]; let mut handshake = vec![0x42u8; total_len];
@ -1140,7 +1178,8 @@ fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len:
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake handshake
} }
@ -1203,8 +1242,7 @@ fn make_valid_tls_client_hello_with_alpn(
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
record record
} }
@ -1233,9 +1271,10 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "11111111111111111111111111111111".to_string()); "11111111111111111111111111111111".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1307,8 +1346,7 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
let bad_after = stats_for_assert.get_connects_bad(); let bad_after = stats_for_assert.get_connects_bad();
assert_eq!( assert_eq!(
bad_before, bad_before, bad_after,
bad_after,
"Authenticated TLS path must not increment connects_bad" "Authenticated TLS path must not increment connects_bad"
); );
} }
@ -1341,9 +1379,10 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "33333333333333333333333333333333".to_string()); "33333333333333333333333333333333".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1394,7 +1433,10 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&tls_app_record).await.unwrap(); client_side.write_all(&tls_app_record).await.unwrap();
@ -1443,9 +1485,10 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "44444444444444444444444444444444".to_string()); "44444444444444444444444444444444".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1563,9 +1606,10 @@ async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() {
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.censorship.alpn_enforce = true; cfg.censorship.alpn_enforce = true;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "66666666666666666666666666666666".to_string()); "66666666666666666666666666666666".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1654,9 +1698,10 @@ async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "77777777777777777777777777777777".to_string()); "77777777777777777777777777777777".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1751,9 +1796,10 @@ async fn burst_invalid_tls_probes_are_masked_verbatim() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "88888888888888888888888888888888".to_string()); "88888888888888888888888888888888".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1981,10 +2027,7 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() {
async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() {
let user = "check-helper-user"; let user = "check-helper-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 1);
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -1998,7 +2041,10 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!(first.is_ok(), "first check-only limit validation must succeed"); assert!(
first.is_ok(),
"first check-only limit validation must succeed"
);
let second = RunningClientHandler::check_user_limits_static( let second = RunningClientHandler::check_user_limits_static(
user, user,
@ -2008,7 +2054,10 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!(second.is_ok(), "second check-only validation must not fail from leaked state"); assert!(
second.is_ok(),
"second check-only validation must not fail from leaked state"
);
assert_eq!(stats.get_user_curr_connects(user), 0); assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
} }
@ -2017,10 +2066,7 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
async fn stress_check_user_limits_static_success_never_leaks_state() { async fn stress_check_user_limits_static_success_never_leaks_state() {
let user = "check-helper-stress-user"; let user = "check-helper-stress-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 1);
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -2039,7 +2085,10 @@ async fn stress_check_user_limits_static_success_never_leaks_state() {
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!(result.is_ok(), "check-only helper must remain leak-free under stress"); assert!(
result.is_ok(),
"check-only helper must remain leak-free under stress"
);
} }
assert_eq!( assert_eq!(
@ -2090,11 +2139,7 @@ async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak()
41000 + i as u16, 41000 + i as u16,
); );
let result = RunningClientHandler::acquire_user_connection_reservation_static( let result = RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await; .await;
assert!(matches!( assert!(matches!(
@ -2130,10 +2175,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 4);
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2171,10 +2213,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 4);
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2204,10 +2243,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 4);
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2248,10 +2284,7 @@ async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() {
let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 4);
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2473,8 +2506,14 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
let user_b = "abort-isolation-b"; let user_b = "abort-isolation-b";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user_a.to_string(), 64); config
config.access.user_max_tcp_conns.insert(user_b.to_string(), 64); .access
.user_max_tcp_conns
.insert(user_a.to_string(), 64);
config
.access
.user_max_tcp_conns
.insert(user_b.to_string(), 64);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2595,10 +2634,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 1);
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
config config
.dc_overrides .dc_overrides
.insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]);
@ -2661,7 +2697,10 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
) )
.await; .await;
assert!(result.is_err(), "relay must fail when upstream DC is unreachable"); assert!(
result.is_err(),
"relay must fail when upstream DC is unreachable"
);
assert_eq!( assert_eq!(
stats.get_user_curr_connects(user), stats.get_user_curr_connects(user),
0, 0,
@ -2680,10 +2719,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 8);
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2743,10 +2779,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 8);
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2802,7 +2835,10 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
#[tokio::test] #[tokio::test]
async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_data_quota.insert("user".to_string(), 1024); config
.access
.user_data_quota
.insert("user".to_string(), 1024);
let stats = Stats::new(); let stats = Stats::new();
stats.add_user_octets_from("user", 1024); stats.add_user_octets_from("user", 1024);
@ -2838,10 +2874,10 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
#[tokio::test] #[tokio::test]
async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_expirations.insert(
.access "user".to_string(),
.user_expirations chrono::Utc::now() - chrono::Duration::seconds(1),
.insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); );
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -2870,10 +2906,7 @@ async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() {
let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 8);
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2914,10 +2947,7 @@ async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() {
let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 8);
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2958,10 +2988,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 8);
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -3005,10 +3032,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 1);
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -3043,11 +3067,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
.expect("cross-thread cleanup must settle before reacquire check"); .expect("cross-thread cleanup must settle before reacquire check");
let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( let reacquire = RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer_addr, ip_tracker,
&config,
stats,
peer_addr,
ip_tracker,
) )
.await; .await;
assert!( assert!(
@ -3113,10 +3133,7 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
.get_recent_ips_for_users(&["user".to_string()]) .get_recent_ips_for_users(&["user".to_string()])
.await; .await;
assert!( assert!(
recent recent.get("user").map(|ips| ips.is_empty()).unwrap_or(true),
.get("user")
.map(|ips| ips.is_empty())
.unwrap_or(true),
"Concurrent rejected attempts must not leave recent IP footprint" "Concurrent rejected attempts must not leave recent IP footprint"
); );
@ -3150,11 +3167,7 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() {
30000 + i, 30000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
"user", "user", &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
.ok() .ok()
@ -3769,9 +3782,10 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "55555555555555555555555555555555".to_string()); "55555555555555555555555555555555".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -3824,7 +3838,10 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut record_header = [0u8; 5]; let mut record_header = [0u8; 5];
client_side.read_exact(&mut record_header).await.unwrap(); client_side.read_exact(&mut record_header).await.unwrap();
assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); assert_eq!(
record_header[0], 0x16,
"Valid max-length ClientHello must be accepted"
);
drop(client_side); drop(client_side);
let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) let handler_result = tokio::time::timeout(Duration::from_secs(3), handler)
@ -3865,9 +3882,10 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access cfg.access.users.insert(
.users "user".to_string(),
.insert("user".to_string(), "66666666666666666666666666666666".to_string()); "66666666666666666666666666666666".to_string(),
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -3938,7 +3956,10 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
let mut record_header = [0u8; 5]; let mut record_header = [0u8; 5];
client.read_exact(&mut record_header).await.unwrap(); client.read_exact(&mut record_header).await.unwrap();
assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); assert_eq!(
record_header[0], 0x16,
"Valid max-length ClientHello must be accepted"
);
drop(client); drop(client);
@ -3947,7 +3968,8 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; let no_mask_connect =
tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await;
assert!( assert!(
no_mask_connect.is_err(), no_mask_connect.is_err(),
"Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path"
@ -4004,11 +4026,7 @@ async fn burst_acquire_distinct_ips(
55000 + i, 55000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
}); });
@ -4190,11 +4208,7 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() {
54000 + i, 54000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, user, &config, stats, peer, ip_tracker,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
}); });
@ -4228,10 +4242,7 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() {
async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() {
let user: &'static str = "scheduled-attack-user"; let user: &'static str = "scheduled-attack-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 6);
.access
.user_max_tcp_conns
.insert(user.to_string(), 6);
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -4240,7 +4251,10 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants()
let mut base = Vec::new(); let mut base = Vec::new();
for i in 0..5u16 { for i in 0..5u16 {
let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i); let peer = SocketAddr::new(
IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)),
56000 + i,
);
let reservation = RunningClientHandler::acquire_user_connection_reservation_static( let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user, user,
&config, &config,
@ -4288,15 +4302,8 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants()
.await .await
.expect("window cleanup must settle to expected occupancy"); .expect("window cleanup must settle to expected occupancy");
let (wave2_success, wave2_fail) = burst_acquire_distinct_ips( let (wave2_success, wave2_fail) =
user, burst_acquire_distinct_ips(user, config, stats.clone(), ip_tracker.clone(), 132, 32).await;
config,
stats.clone(),
ip_tracker.clone(),
132,
32,
)
.await;
assert_eq!(wave2_success.len(), 1); assert_eq!(wave2_success.len(), 1);
assert_eq!(wave2_fail, 31); assert_eq!(wave2_fail, 31);
assert_eq!(stats.get_user_curr_connects(user), 5); assert_eq!(stats.get_user_curr_connects(user), 5);

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
@ -135,7 +135,10 @@ async fn run_generic_once(class: ProbeClass) -> u128 {
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
fn test_probe_for_len(len: usize) -> [u8; 5] { fn test_probe_for_len(len: usize) -> [u8; 5] {
@ -100,7 +100,10 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) {
client_side.write_all(&probe).await.unwrap(); client_side.write_all(&probe).await.unwrap();
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_side.read_exact(&mut observed).await.unwrap(); client_side.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply, "invalid TLS path must be masked as a real site"); assert_eq!(
observed, backend_reply,
"invalid TLS path must be masked as a real site"
);
drop(client_side); drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler) let _ = tokio::time::timeout(Duration::from_secs(3), handler)
@ -109,7 +112,11 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) {
.unwrap(); .unwrap();
accept_task.await.unwrap(); accept_task.await.unwrap();
let expected_bad = if expect_bad_increment { bad_before + 1 } else { bad_before }; let expected_bad = if expect_bad_increment {
bad_before + 1
} else {
bad_before
};
assert_eq!( assert_eq!(
stats.get_connects_bad(), stats.get_connects_bad(),
expected_bad, expected_bad,
@ -187,7 +194,9 @@ fn tls_client_hello_len_bounds_stress_many_evaluations() {
for _ in 0..100_000 { for _ in 0..100_000 {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); assert!(!tls_clienthello_len_in_bounds(
MIN_TLS_CLIENT_HELLO_SIZE - 1
));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
} }
} }

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep; use tokio::time::sleep;
@ -48,7 +48,12 @@ fn truncated_in_range_record(actual_body_len: usize) -> Vec<u8> {
out out
} }
async fn write_fragmented<W: AsyncWriteExt + Unpin>(writer: &mut W, bytes: &[u8], chunks: &[usize], delay_ms: u64) { async fn write_fragmented<W: AsyncWriteExt + Unpin>(
writer: &mut W,
bytes: &[u8],
chunks: &[usize],
delay_ms: u64,
) {
let mut offset = 0usize; let mut offset = 0usize;
for &chunk in chunks { for &chunk in chunks {
if offset >= bytes.len() { if offset >= bytes.len() {
@ -130,7 +135,10 @@ async fn run_blackhat_generic_fragmented_probe_should_mask(
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
@ -311,7 +319,10 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() {
// Security expectation: even malformed in-range TLS should be masked. // Security expectation: even malformed in-range TLS should be masked.
// This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior.
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) tokio::time::timeout(
Duration::from_secs(2),
client_side.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();

View File

@ -2,16 +2,11 @@ use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{ use crate::protocol::constants::{
HANDSHAKE_LEN, HANDSHAKE_LEN, MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION,
MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
TLS_RECORD_ALERT,
TLS_RECORD_APPLICATION,
TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE,
TLS_VERSION,
}; };
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
struct PipelineHarness { struct PipelineHarness {
@ -74,7 +69,10 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -181,11 +179,17 @@ async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -246,10 +250,16 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -264,7 +274,11 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() {
.unwrap(); .unwrap();
let bad_after = stats_for_assert.get_connects_bad(); let bad_after = stats_for_assert.get_connects_bad();
assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS"); assert_eq!(
bad_after,
bad_before + 1,
"connects_bad must increase exactly once for invalid MTProto after valid TLS"
);
} }
#[tokio::test] #[tokio::test]
@ -310,10 +324,16 @@ async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -372,10 +392,16 @@ async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -399,7 +425,8 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() {
let backend_addr = listener.local_addr().unwrap(); let backend_addr = listener.local_addr().unwrap();
let secret = [0x85u8; 16]; let secret = [0x85u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); let client_hello =
make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8);
let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto = vec![0u8; HANDSHAKE_LEN];
let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto);
@ -443,10 +470,16 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -498,7 +531,10 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
); );
} }
assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once"); assert!(
remaining.is_empty(),
"all expected client sessions must be matched exactly once"
);
}); });
let mut client_tasks = Vec::with_capacity(sessions); let mut client_tasks = Vec::with_capacity(sessions);
@ -506,7 +542,8 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
for idx in 0..sessions { for idx in 0..sessions {
let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); let harness = build_harness("86868686868686868686868686868686", backend_addr.port());
let secret = [0x86u8; 16]; let secret = [0x86u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); let client_hello =
make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8);
let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto = vec![0u8; HANDSHAKE_LEN];
let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto);
let trailing_payload = vec![idx as u8; 64 + idx]; let trailing_payload = vec![idx as u8; 64 + idx];
@ -538,10 +575,16 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
drop(client_side); drop(client_side);
@ -606,10 +649,16 @@ async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for chunk in trailing_record.chunks(3) { for chunk in trailing_record.chunks(3) {
client_side.write_all(chunk).await.unwrap(); client_side.write_all(chunk).await.unwrap();
@ -669,10 +718,16 @@ async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for b in trailing_record.iter().copied() { for b in trailing_record.iter().copied() {
client_side.write_all(&[b]).await.unwrap(); client_side.write_all(&[b]).await.unwrap();
} }
@ -736,10 +791,16 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17];
let mut pos = 0usize; let mut pos = 0usize;
@ -747,7 +808,10 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() {
while pos < trailing_record.len() { while pos < trailing_record.len() {
let step = chaos[idx % chaos.len()]; let step = chaos[idx % chaos.len()];
let end = (pos + step).min(trailing_record.len()); let end = (pos + step).min(trailing_record.len());
client_side.write_all(&trailing_record[pos..end]).await.unwrap(); client_side
.write_all(&trailing_record[pos..end])
.await
.unwrap();
pos = end; pos = end;
idx += 1; idx += 1;
} }
@ -809,10 +873,16 @@ async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&r1).await.unwrap(); client_side.write_all(&r1).await.unwrap();
client_side.write_all(&r2).await.unwrap(); client_side.write_all(&r2).await.unwrap();
client_side.write_all(&r3).await.unwrap(); client_side.write_all(&r3).await.unwrap();
@ -848,7 +918,10 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend()
let mut tail = [0u8; 1]; let mut tail = [0u8; 1];
let n = stream.read(&mut tail).await.unwrap(); let n = stream.read(&mut tail).await.unwrap();
assert_eq!(n, 0, "backend must observe EOF after client write half-close"); assert_eq!(
n, 0,
"backend must observe EOF after client write half-close"
);
}); });
let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port());
@ -874,10 +947,16 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
@ -938,11 +1017,17 @@ async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -994,10 +1079,16 @@ async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let write_res = client_side.write_all(&trailing_record).await; let write_res = client_side.write_all(&trailing_record).await;
assert!( assert!(
write_res.is_ok() || write_res.is_err(), write_res.is_ok() || write_res.is_err(),
@ -1068,10 +1159,16 @@ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(5), accept_task) tokio::time::timeout(Duration::from_secs(5), accept_task)
@ -1152,7 +1249,10 @@ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhel
let mut head = [0u8; 5]; let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -1241,10 +1341,16 @@ async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31];
for record in [&a, &b, &c] { for record in [&a, &b, &c] {
@ -1316,10 +1422,16 @@ async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_ve
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side.read_exact(&mut tls_response_head).await.unwrap(); client_side
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&ccs).await.unwrap(); client_side.write_all(&ccs).await.unwrap();
client_side.write_all(&app).await.unwrap(); client_side.write_all(&app).await.unwrap();
client_side.write_all(&alert).await.unwrap(); client_side.write_all(&alert).await.unwrap();
@ -1372,7 +1484,10 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak()
); );
} }
assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once"); assert!(
remaining.is_empty(),
"all expected sessions must be consumed exactly once"
);
}); });
let mut tasks = Vec::with_capacity(sessions); let mut tasks = Vec::with_capacity(sessions);
@ -1413,7 +1528,10 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak()
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for chunk in record.chunks((idx % 9) + 1) { for chunk in record.chunks((idx % 9) + 1) {
client_side.write_all(chunk).await.unwrap(); client_side.write_all(chunk).await.unwrap();
} }
@ -2520,7 +2638,10 @@ async fn blackhat_coalesced_tail_parallel_32_sessions_no_cross_bleed() {
"session mixup detected in parallel-32 blackhat test" "session mixup detected in parallel-32 blackhat test"
); );
} }
assert!(remaining.is_empty(), "all expected sessions must be consumed"); assert!(
remaining.is_empty(),
"all expected sessions must be consumed"
);
}); });
let mut tasks = Vec::with_capacity(sessions); let mut tasks = Vec::with_capacity(sessions);

View File

@ -0,0 +1,37 @@
use super::*;
#[test]
fn wrap_tls_application_record_empty_payload_emits_zero_length_record() {
let record = wrap_tls_application_record(&[]);
assert_eq!(record.len(), 5);
assert_eq!(record[0], TLS_RECORD_APPLICATION);
assert_eq!(&record[1..3], &TLS_VERSION);
assert_eq!(&record[3..5], &0u16.to_be_bytes());
}
#[test]
fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() {
let total = (u16::MAX as usize) + 37;
let payload = vec![0xA5u8; total];
let record = wrap_tls_application_record(&payload);
let mut offset = 0usize;
let mut recovered = Vec::with_capacity(total);
let mut frames = 0usize;
while offset + 5 <= record.len() {
assert_eq!(record[offset], TLS_RECORD_APPLICATION);
assert_eq!(&record[offset + 1..offset + 3], &TLS_VERSION);
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
let body_start = offset + 5;
let body_end = body_start + len;
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
recovered.extend_from_slice(&record[body_start..body_end]);
offset = body_end;
frames += 1;
}
assert_eq!(offset, record.len(), "record parser must consume exact output size");
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
assert_eq!(recovered, payload, "chunked records must preserve full payload");
}

View File

@ -5,7 +5,10 @@ use std::net::SocketAddr;
#[test] #[test]
fn business_scope_hint_accepts_exact_boundary_length() { fn business_scope_hint_accepts_exact_boundary_length() {
let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN)); let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN));
assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); assert_eq!(
validated_scope_hint(&value),
Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
);
} }
#[test] #[test]
@ -24,7 +27,8 @@ fn business_known_dc_uses_ipv4_table_by_default() {
#[test] #[test]
fn business_negative_dc_maps_by_absolute_value() { fn business_negative_dc_maps_by_absolute_value() {
let cfg = ProxyConfig::default(); let cfg = ProxyConfig::default();
let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); let resolved =
get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value");
let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }
@ -45,7 +49,8 @@ fn business_unknown_dc_uses_configured_default_dc_when_in_range() {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.default_dc = Some(4); cfg.default_dc = Some(4);
let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); let resolved =
get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default");
let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }

View File

@ -12,7 +12,8 @@ fn common_invalid_override_entries_fallback_to_static_table() {
vec!["bad-address".to_string(), "still-bad".to_string()], vec!["bad-address".to_string(), "still-bad".to_string()],
); );
let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve"); let resolved =
get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }
@ -25,7 +26,8 @@ fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it
cfg.dc_overrides cfg.dc_overrides
.insert("3".to_string(), vec!["203.0.113.203:443".to_string()]); .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists"); let resolved =
get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap()); assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap());
} }

View File

@ -15,7 +15,7 @@ use std::time::Duration;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration as TokioDuration}; use tokio::time::{Duration as TokioDuration, timeout};
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where where
@ -79,7 +79,9 @@ fn unknown_dc_log_respects_distinct_limit() {
#[test] #[test]
fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() {
let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<i16>::new())); let poisoned = Arc::new(std::sync::Mutex::new(
std::collections::HashSet::<i16>::new(),
));
let poisoned_for_thread = poisoned.clone(); let poisoned_for_thread = poisoned.clone();
let _ = std::thread::spawn(move || { let _ = std::thread::spawn(move || {
@ -243,7 +245,10 @@ fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() {
fs::create_dir_all(&base).expect("temp test directory must be creatable"); fs::create_dir_all(&base).expect("temp test directory must be creatable");
let candidate = base.join("unknown-dc.txt"); let candidate = base.join("unknown-dc.txt");
let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id()); let candidate_relative = format!(
"target/telemt-unknown-dc-log-{}/unknown-dc.txt",
std::process::id()
);
let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) let sanitized = sanitize_unknown_dc_log_path(&candidate_relative)
.expect("safe relative path with existing parent must be accepted"); .expect("safe relative path with existing parent must be accepted");
@ -325,7 +330,10 @@ fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-log-symlink-internal-{}",
std::process::id()
));
let real_parent = base.join("real_parent"); let real_parent = base.join("real_parent");
fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); fs::create_dir_all(&real_parent).expect("real parent dir must be creatable");
@ -354,7 +362,10 @@ fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-log-symlink-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("symlink test directory must be creatable"); fs::create_dir_all(&base).expect("symlink test directory must be creatable");
let symlink_parent = base.join("escape_link"); let symlink_parent = base.join("escape_link");
@ -382,7 +393,10 @@ fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-target-link-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-target-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("target-link base must be creatable"); fs::create_dir_all(&base).expect("target-link base must be creatable");
let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id()));
@ -445,7 +459,10 @@ fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-broken-link-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-broken-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("broken-link base must be creatable"); fs::create_dir_all(&base).expect("broken-link base must be creatable");
let linked_target = base.join("unknown-dc.log"); let linked_target = base.join("unknown-dc.log");
@ -470,7 +487,10 @@ fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-symlink-flip-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); fs::create_dir_all(&base).expect("symlink-flip base must be creatable");
let outside = std::env::temp_dir().join(format!( let outside = std::env::temp_dir().join(format!(
@ -530,7 +550,10 @@ fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-open-stress-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-open-stress-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("stress open base must be creatable"); fs::create_dir_all(&base).expect("stress open base must be creatable");
let target = base.join("unknown-dc.log"); let target = base.join("unknown-dc.log");
@ -556,7 +579,10 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-safe-target-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-safe-target-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("safe target base must be creatable"); fs::create_dir_all(&base).expect("safe target base must be creatable");
let target = base.join("unknown-dc.log"); let target = base.join("unknown-dc.log");
@ -566,8 +592,8 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() {
"target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log",
std::process::id() std::process::id()
); );
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) let sanitized =
.expect("safe candidate must sanitize"); sanitize_unknown_dc_log_path(&rel_candidate).expect("safe candidate must sanitize");
assert!( assert!(
unknown_dc_log_path_is_still_safe(&sanitized), unknown_dc_log_path_is_still_safe(&sanitized),
"revalidation must allow safe existing regular files" "revalidation must allow safe existing regular files"
@ -579,7 +605,10 @@ fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-vanish-parent-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); fs::create_dir_all(&base).expect("vanish-parent base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -604,7 +633,10 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
let parent = std::env::current_dir() let parent = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-parent-swap-{}",
std::process::id()
));
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -633,7 +665,10 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
let parent = std::env::current_dir() let parent = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-check-open-race-{}",
std::process::id()
));
fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); fs::create_dir_all(&parent).expect("check-open-race parent must be creatable");
let target = parent.join("unknown-dc.log"); let target = parent.join("unknown-dc.log");
@ -642,8 +677,7 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
"target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log",
std::process::id() std::process::id()
); );
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
.expect("candidate must sanitize");
assert!( assert!(
unknown_dc_log_path_is_still_safe(&sanitized), unknown_dc_log_path_is_still_safe(&sanitized),
@ -675,7 +709,10 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-parent-swap-openat-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -708,7 +745,10 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
.expect_err("anchored open must fail when parent is swapped to symlink"); .expect_err("anchored open must fail when parent is swapped to symlink");
let raw = err.raw_os_error(); let raw = err.raw_os_error();
assert!( assert!(
matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)), matches!(
raw,
Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)
),
"anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}"
); );
assert!( assert!(
@ -717,6 +757,284 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
); );
} }
#[cfg(unix)]
#[test]
fn anchored_open_nix_path_writes_expected_lines() {
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-anchored-open-ok-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("anchored-open-ok base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut first = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored open must create log file in allowed parent");
append_unknown_dc_line(&mut first, 31_200).expect("first append must succeed");
let mut second = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored reopen must succeed for existing regular file");
append_unknown_dc_line(&mut second, 31_201).expect("second append must succeed");
let content =
fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(lines.len(), 2, "expected one line per anchored append call");
assert!(
lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"),
"anchored append output must contain both expected dc_idx lines"
);
}
#[cfg(unix)]
#[test]
fn anchored_open_parallel_appends_preserve_line_integrity() {
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-anchored-open-parallel-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("anchored-open-parallel base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut workers = Vec::new();
for idx in 0..64i16 {
let sanitized = sanitized.clone();
workers.push(std::thread::spawn(move || {
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored open must succeed in worker");
append_unknown_dc_line(&mut file, 32_000 + idx).expect("worker append must succeed");
}));
}
for worker in workers {
worker.join().expect("worker must not panic");
}
let content =
fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(lines.len(), 64, "expected one complete line per worker append");
for line in lines {
assert!(
line.starts_with("dc_idx="),
"line must keep dc_idx prefix and not be interleaved: {line}"
);
let value = line
.strip_prefix("dc_idx=")
.expect("prefix checked above")
.parse::<i16>();
assert!(
value.is_ok(),
"line payload must remain parseable i16 and not be corrupted: {line}"
);
}
}
#[cfg(unix)]
#[test]
fn anchored_open_creates_private_0600_file_permissions() {
use std::os::unix::fs::PermissionsExt;
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-anchored-perms-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("anchored-perms base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored open must create file with restricted mode");
append_unknown_dc_line(&mut file, 31_210).expect("initial append must succeed");
drop(file);
let mode = fs::metadata(&sanitized.resolved_path)
.expect("created log file metadata must be readable")
.permissions()
.mode()
& 0o777;
assert_eq!(
mode, 0o600,
"anchored open must create unknown-dc log file with owner-only rw permissions"
);
}
#[cfg(unix)]
#[test]
fn anchored_open_rejects_existing_symlink_target() {
use std::os::unix::fs::symlink;
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-anchored-symlink-target-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("anchored-symlink-target base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let outside = std::env::temp_dir().join(format!(
"telemt-unknown-dc-anchored-symlink-outside-{}.log",
std::process::id()
));
fs::write(&outside, "outside\n").expect("outside baseline file must be writable");
let _ = fs::remove_file(&sanitized.resolved_path);
symlink(&outside, &sanitized.resolved_path)
.expect("target symlink for anchored-open rejection test must be creatable");
let err = open_unknown_dc_log_append_anchored(&sanitized)
.expect_err("anchored open must reject symlinked filename target");
assert_eq!(
err.raw_os_error(),
Some(libc::ELOOP),
"anchored open should fail closed with ELOOP on symlinked target"
);
}
#[cfg(unix)]
#[test]
fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-anchored-contention-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("anchored-contention base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let workers = 24usize;
let rounds = 40usize;
let mut threads = Vec::new();
for worker in 0..workers {
let sanitized = sanitized.clone();
threads.push(std::thread::spawn(move || {
for round in 0..rounds {
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
.expect("anchored open must succeed under contention");
let dc_idx = 20_000i16.wrapping_add((worker * rounds + round) as i16);
append_unknown_dc_line(&mut file, dc_idx)
.expect("each contention append must complete");
}
}));
}
for thread in threads {
thread.join().expect("contention worker must not panic");
}
let content = fs::read_to_string(&sanitized.resolved_path)
.expect("contention output file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(
lines.len(),
workers * rounds,
"every contention append must produce exactly one line"
);
let mut unique = std::collections::HashSet::new();
for line in lines {
assert!(
line.starts_with("dc_idx="),
"line must preserve expected prefix under heavy contention: {line}"
);
let value = line
.strip_prefix("dc_idx=")
.expect("prefix validated")
.parse::<i16>()
.expect("line payload must remain parseable i16 under contention");
unique.insert(value);
}
assert_eq!(
unique.len(),
workers * rounds,
"contention output must not lose or duplicate logical writes"
);
}
#[cfg(unix)]
#[test]
fn append_unknown_dc_line_returns_error_for_read_only_descriptor() {
let base = std::env::current_dir()
.expect("cwd must be available")
.join("target")
.join(format!(
"telemt-unknown-dc-append-ro-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("append-ro base must be creatable");
let rel_candidate = format!(
"target/telemt-unknown-dc-append-ro-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable");
let mut readonly = std::fs::OpenOptions::new()
.read(true)
.open(&sanitized.resolved_path)
.expect("readonly file open must succeed");
append_unknown_dc_line(&mut readonly, 31_222)
.expect_err("append on readonly descriptor must fail closed");
let content_after =
fs::read_to_string(&sanitized.resolved_path).expect("seed file must remain readable");
assert_eq!(
nonempty_line_count(&content_after),
1,
"failed readonly append must not modify persisted unknown-dc log content"
);
}
#[tokio::test] #[tokio::test]
async fn unknown_dc_absolute_log_path_writes_one_entry() { async fn unknown_dc_absolute_log_path_writes_one_entry() {
let _guard = unknown_dc_test_lock() let _guard = unknown_dc_test_lock()
@ -896,7 +1214,10 @@ async fn unknown_dc_symlinked_target_escape_is_not_written_integration() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id())); .join(format!(
"telemt-unknown-dc-no-write-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("integration symlink base must be creatable"); fs::create_dir_all(&base).expect("integration symlink base must be creatable");
let outside = std::env::temp_dir().join(format!( let outside = std::env::temp_dir().join(format!(
@ -1024,11 +1345,17 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
} }
}) })
.await; .await;
assert!(started.is_ok(), "direct relay must increment route gauge before abort"); assert!(
started.is_ok(),
"direct relay must increment route gauge before abort"
);
relay_task.abort(); relay_task.abort();
let joined = relay_task.await; let joined = relay_task.await;
assert!(joined.is_err(), "aborted direct relay task must return join error"); assert!(
joined.is_err(),
"aborted direct relay task must return join error"
);
tokio::time::sleep(Duration::from_millis(20)).await; tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!( assert_eq!(
@ -1313,15 +1640,22 @@ fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() {
], ],
); );
let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve");
assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present"); assert!(
a.is_ipv6(),
"prefer_v6 should choose v6 override when present"
);
let mut cfg_b = ProxyConfig::default(); let mut cfg_b = ProxyConfig::default();
cfg_b.network.prefer = 6; cfg_b.network.prefer = 6;
cfg_b.network.ipv6 = Some(true); cfg_b.network.ipv6 = Some(true);
cfg_b.dc_overrides cfg_b
.dc_overrides
.insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]);
let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve");
assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used"); assert!(
b.is_ipv4(),
"when no v6 override exists, v4 override must be used"
);
let mut cfg_c = ProxyConfig::default(); let mut cfg_c = ProxyConfig::default();
cfg_c.network.prefer = 6; cfg_c.network.prefer = 6;
@ -1350,7 +1684,8 @@ fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallb
], ],
); );
let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive"); let addr = get_dc_addr_static(dc_idx, &cfg)
.expect("at least one valid override must keep resolution alive");
assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap()); assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap());
} }
@ -1370,7 +1705,10 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve"); let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve");
let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve"); let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve");
assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}"); assert_eq!(
first, second,
"override resolution must stay deterministic for dc {idx}"
);
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
} }
} }
@ -1397,7 +1735,9 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
drop(listener); drop(listener);
let mut config_with_override = ProxyConfig::default(); let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); config_with_override
.dc_overrides
.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override); let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new( let upstream_manager = Arc::new(UpstreamManager::new(
@ -1485,7 +1825,9 @@ async fn adversarial_direct_relay_cutover_integrity() {
}); });
let mut config_with_override = ProxyConfig::default(); let mut config_with_override = ProxyConfig::default();
config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); config_with_override
.dc_overrides
.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override); let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new( let upstream_manager = Arc::new(UpstreamManager::new(
@ -1534,7 +1876,8 @@ async fn adversarial_direct_relay_cutover_integrity() {
runtime_clone.subscribe(), runtime_clone.subscribe(),
runtime_clone.snapshot(), runtime_clone.snapshot(),
0xABCD_1234, 0xABCD_1234,
).await )
.await
}); });
timeout(TokioDuration::from_secs(2), async { timeout(TokioDuration::from_secs(2), async {

View File

@ -40,9 +40,7 @@ fn subtle_light_fuzz_scope_hint_matches_oracle() {
}; };
!rest.is_empty() !rest.is_empty()
&& rest.len() <= MAX_SCOPE_HINT_LEN && rest.len() <= MAX_SCOPE_HINT_LEN
&& rest && rest.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-')
.bytes()
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
} }
let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; let mut state: u64 = 0xC0FF_EE11_D15C_AFE5;
@ -94,7 +92,10 @@ fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() {
let dc_idx = (state as i16).wrapping_sub(16_384); let dc_idx = (state as i16).wrapping_sub(16_384);
let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail");
assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT); assert_eq!(
resolved.port(),
crate::protocol::constants::TG_DATACENTER_PORT
);
let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true);
assert_eq!(resolved.is_ipv6(), expect_v6); assert_eq!(resolved.is_ipv6(), expect_v6);
} }
@ -166,7 +167,9 @@ async fn subtle_integration_parallel_unique_dcs_log_unique_lines() {
cfg.general.unknown_dc_log_path = Some(rel_file); cfg.general.unknown_dc_log_path = Some(rel_file);
let cfg = Arc::new(cfg); let cfg = Arc::new(cfg);
let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908]; let dcs = [
31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908,
];
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for dc in dcs { for dc in dcs {

View File

@ -1,10 +1,14 @@
use super::*; use super::*;
use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
use crate::crypto::sha256; use crate::crypto::sha256;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::time::{Duration, Instant};
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode"); let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN]; let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
@ -49,7 +53,9 @@ fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string()); cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true; cfg.general.modes.secure = true;
cfg cfg
@ -71,9 +77,19 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() {
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
// Baseline check // Baseline check
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res { match res {
HandshakeResult::Success(_) => {}, HandshakeResult::Success(_) => {}
_ => panic!("Baseline failed: expected Success"), _ => panic!("Baseline failed: expected Success"),
} }
@ -81,8 +97,21 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() {
for byte_pos in SKIP_LEN..HANDSHAKE_LEN { for byte_pos in SKIP_LEN..HANDSHAKE_LEN {
let mut h = base; let mut h = base;
h[byte_pos] ^= 0x01; // Flip 1 bit h[byte_pos] ^= 0x01; // Flip 1 bit
let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let res = handle_mtproto_handshake(
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Flip at byte {byte_pos} bit 0 must be rejected"); &h,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Flip at byte {byte_pos} bit 0 must be rejected"
);
} }
} }
@ -102,7 +131,17 @@ async fn mtproto_handshake_timing_neutrality_mocked() {
let mut start = Instant::now(); let mut start = Instant::now();
for _ in 0..ITER { for _ in 0..ITER {
let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let _ = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
} }
let duration_success = start.elapsed(); let duration_success = start.elapsed();
@ -110,14 +149,30 @@ async fn mtproto_handshake_timing_neutrality_mocked() {
for i in 0..ITER { for i in 0..ITER {
let mut h = base; let mut h = base;
h[SKIP_LEN + (i % 48)] ^= 0xFF; h[SKIP_LEN + (i % 48)] ^= 0xFF;
let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let _ = handle_mtproto_handshake(
&h,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
} }
let duration_fail = start.elapsed(); let duration_fail = start.elapsed();
let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64; let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64)
.abs()
/ ITER as f64;
// Threshold (loose for CI) // Threshold (loose for CI)
assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms); assert!(
avg_diff_ms < 100.0,
"Timing difference too large: {} ms/iter",
avg_diff_ms
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -145,10 +200,7 @@ async fn auth_probe_throttle_saturation_stress() {
auth_probe_record_failure(ip, now); auth_probe_record_failure(ip, now);
} }
let tracked = AUTH_PROBE_STATE let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
.get()
.map(|state| state.len())
.unwrap_or(0);
assert!( assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}"
@ -166,7 +218,17 @@ async fn mtproto_handshake_abridged_prefix_rejected() {
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
// MTProxy stops immediately on 0xef // MTProxy stops immediately on 0xef
assert!(matches!(res, HandshakeResult::BadClient { .. })); assert!(matches!(res, HandshakeResult::BadClient { .. }));
} }
@ -181,8 +243,14 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() {
let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1);
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.users.insert("user1".to_string(), secret1_hex.to_string()); config
config.access.users.insert("user2".to_string(), secret2_hex.to_string()); .access
.users
.insert("user1".to_string(), secret1_hex.to_string());
config
.access
.users
.insert("user2".to_string(), secret2_hex.to_string());
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
config.general.modes.secure = true; config.general.modes.secure = true;
@ -190,7 +258,17 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() {
let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap();
// Even if we prefer user1, if user2 matches, it should succeed. // Even if we prefer user1, if user2 matches, it should succeed.
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await; let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
Some("user1"),
)
.await;
if let HandshakeResult::Success((_, _, success)) = res { if let HandshakeResult::Success((_, _, success)) = res {
assert_eq!(success.user, "user2"); assert_eq!(success.user, "user2");
} else { } else {
@ -218,7 +296,17 @@ async fn mtproto_handshake_concurrent_flood_stability() {
let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; let res = handle_mtproto_handshake(
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
matches!(res, HandshakeResult::Success(_)) matches!(res, HandshakeResult::Success(_))
})); }));
} }
@ -306,7 +394,10 @@ async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() {
.expect("fuzzed mutation must complete in bounded time"); .expect("fuzzed mutation must complete in bounded time");
assert!( assert!(
matches!(res, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)), matches!(
res,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutation corpus must stay within explicit handshake outcomes" "mutation corpus must stay within explicit handshake outcomes"
); );
} }
@ -345,7 +436,12 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) { for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) {
let peer: SocketAddr = SocketAddr::new( let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(10, (i / 65535) as u8, ((i / 255) % 255) as u8, (i % 255 + 1) as u8)), IpAddr::V4(Ipv4Addr::new(
10,
(i / 65535) as u8,
((i / 255) % 255) as u8,
(i % 255 + 1) as u8,
)),
43000 + (i % 20000) as u16, 43000 + (i % 20000) as u16,
); );
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(
@ -362,10 +458,7 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
assert!(matches!(res, HandshakeResult::BadClient { .. })); assert!(matches!(res, HandshakeResult::BadClient { .. }));
} }
let tracked = AUTH_PROBE_STATE let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
.get()
.map(|state| state.len())
.unwrap_or(0);
assert!( assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"probe map must remain bounded under invalid storm: {tracked}" "probe map must remain bounded under invalid storm: {tracked}"
@ -415,7 +508,10 @@ async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() {
.expect("mutation iteration must complete in bounded time"); .expect("mutation iteration must complete in bounded time");
assert!( assert!(
matches!(outcome, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)), matches!(
outcome,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutations must remain fail-closed/auth-only" "mutations must remain fail-closed/auth-only"
); );
} }

View File

@ -6,7 +6,7 @@ use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker; use crate::stats::ReplayChecker;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::MutexGuard; use std::sync::MutexGuard;
use tokio::time::{timeout, Duration as TokioDuration}; use tokio::time::{Duration as TokioDuration, timeout};
fn make_mtproto_handshake_with_proto_bytes( fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str, secret_hex: &str,
@ -48,14 +48,20 @@ fn make_mtproto_handshake_with_proto_bytes(
handshake handshake
} }
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
} }
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
cfg.access.users.insert("user".to_string(), secret_hex.to_string()); cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true; cfg.general.modes.secure = true;
cfg cfg
@ -140,7 +146,9 @@ async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
for _ in 0..32 { for _ in 0..32 {
let mut mutated = base; let mut mutated = base;
for _ in 0..4 { for _ in 0..4 {
seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); seed = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
} }

View File

@ -1,8 +1,8 @@
use super::*; use super::*;
use crate::crypto::{sha256, sha256_hmac}; use crate::crypto::{sha256, sha256_hmac};
use dashmap::DashMap; use dashmap::DashMap;
use rand::{RngExt, SeedableRng};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn(
for i in 0..4 { for i in 0..4 {
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
record record
} }
@ -151,8 +150,7 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
for i in 0..4 { for i in 0..4 {
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
record record
} }
@ -167,7 +165,11 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
cfg cfg
} }
fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper");
let mut handshake = [0x5Au8; HANDSHAKE_LEN]; let mut handshake = [0x5Au8; HANDSHAKE_LEN];
@ -328,7 +330,10 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
expected.extend_from_slice(&client_enc_iv.to_be_bytes()); expected.extend_from_slice(&client_enc_iv.to_be_bytes());
expected.reverse(); expected.reverse();
assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice()); assert_eq!(
&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN],
expected.as_slice()
);
} }
#[test] #[test]
@ -445,7 +450,9 @@ async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_bloc
#[tokio::test] #[tokio::test]
async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() {
let secret = [0x77u8; 16]; let secret = [0x77u8; 16];
let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777")); let config = Arc::new(test_config_with_secret_hex(
"77777777777777777777777777777777",
));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); let handshake = Arc::new(make_valid_tls_handshake(&secret, 0));
@ -785,10 +792,10 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() {
.access .access
.users .users
.insert("broken_user".to_string(), "aa".to_string()); .insert("broken_user".to_string(), "aa".to_string());
config config.access.users.insert(
.access "valid_user".to_string(),
.users "22222222222222222222222222222222".to_string(),
.insert("valid_user".to_string(), "22222222222222222222222222222222".to_string()); );
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
@ -829,12 +836,8 @@ async fn tls_sni_preferred_user_hint_selects_matching_identity_first() {
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_sni_and_alpn( let handshake =
&shared_secret, make_valid_tls_client_hello_with_sni_and_alpn(&shared_secret, 0, "user-b", &[b"h2"]);
0,
"user-b",
&[b"h2"],
);
let result = handle_tls_handshake( let result = handle_tls_handshake(
&handshake, &handshake,
@ -868,10 +871,10 @@ fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() {
let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();
for i in 0..4096usize { for i in 0..4096usize {
config.access.users.insert( config
format!("decoy-{i:04}.example"), .access
secret_hex.clone(), .users
); .insert(format!("decoy-{i:04}.example"), secret_hex.clone());
} }
config config
.access .access
@ -910,10 +913,10 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();
for i in 0..4096usize { for i in 0..4096usize {
config.access.users.insert( config
format!("decoy-{i:04}.example"), .access
secret_hex.clone(), .users
); .insert(format!("decoy-{i:04}.example"), secret_hex.clone());
} }
config config
.access .access
@ -945,8 +948,7 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
match result { match result {
HandshakeResult::Success((_, _, user)) => { HandshakeResult::Success((_, _, user)) => {
assert_eq!( assert_eq!(
user, user, preferred_user,
preferred_user,
"SNI preferred-user hint must remain stable under large user cardinality" "SNI preferred-user hint must remain stable under large user cardinality"
); );
} }
@ -1880,11 +1882,15 @@ fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() {
"different IPv6 /64 prefixes must not share throttle buckets" "different IPv6 /64 prefixes must not share throttle buckets"
); );
assert_eq!( assert_eq!(
state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak), state
.get(&normalize_auth_probe_ip(ip_a))
.map(|entry| entry.fail_streak),
Some(1) Some(1)
); );
assert_eq!( assert_eq!(
state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak), state
.get(&normalize_auth_probe_ip(ip_b))
.map(|entry| entry.fail_streak),
Some(1) Some(1)
); );
} }
@ -1944,7 +1950,6 @@ fn auth_probe_eviction_offset_changes_with_time_component() {
); );
} }
#[test] #[test]
fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() { fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() {
let _guard = auth_probe_test_lock() let _guard = auth_probe_test_lock()
@ -1986,7 +1991,10 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40)); let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40));
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1)); auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1));
assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure"); assert!(
state.get(&newcomer).is_some(),
"newcomer must still be tracked under over-cap pressure"
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"high fail-streak sentinel must survive round-limited eviction" "high fail-streak sentinel must survive round-limited eviction"
@ -2077,13 +2085,20 @@ fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket()
((step >> 8) & 0xff) as u8, ((step >> 8) & 0xff) as u8,
(step & 0xff) as u8, (step & 0xff) as u8,
)); ));
auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1)); auth_probe_record_failure_with_state(
&state,
newcomer,
base_now + Duration::from_millis(step as u64 + 1),
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"step {step}: high-threat sentinel must not be starved by newcomer churn" "step {step}: high-threat sentinel must not be starved by newcomer churn"
); );
assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked"); assert!(
state.get(&newcomer).is_some(),
"step {step}: newcomer must be tracked"
);
} }
} }
@ -2129,10 +2144,22 @@ fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() {
); );
} }
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8)); let newcomer = IpAddr::V4(Ipv4Addr::new(
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1)); 203,
10,
((round >> 8) & 0xff) as u8,
(round & 0xff) as u8,
));
auth_probe_record_failure_with_state(
&state,
newcomer,
now + Duration::from_millis(round as u64 + 1),
);
assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked"); assert!(
state.get(&newcomer).is_some(),
"round {round}: newcomer should be tracked"
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"round {round}: high fail-streak sentinel should survive mixed low-threat pool" "round {round}: high fail-streak sentinel should survive mixed low-threat pool"
@ -2145,7 +2172,12 @@ fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() {
let base = Instant::now(); let base = Instant::now();
for _ in 0..4096usize { for _ in 0..4096usize {
let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random())); let ip = IpAddr::V4(Ipv4Addr::new(
rng.random(),
rng.random(),
rng.random(),
rng.random(),
));
let offset_ns = rng.random_range(0_u64..2_000_000); let offset_ns = rng.random_range(0_u64..2_000_000);
let when = base + Duration::from_nanos(offset_ns); let when = base + Duration::from_nanos(offset_ns);
@ -2244,8 +2276,7 @@ async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() {
let streak = auth_probe_fail_streak_for_testing(peer_ip) let streak = auth_probe_fail_streak_for_testing(peer_ip)
.expect("tracked peer must exist after concurrent failure burst"); .expect("tracked peer must exist after concurrent failure burst");
assert_eq!( assert_eq!(
streak as usize, streak as usize, tasks,
tasks,
"concurrent failures for one source must account every attempt" "concurrent failures for one source must account every attempt"
); );
} }
@ -2258,7 +2289,9 @@ async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake()
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
let secret = [0x31u8; 16]; let secret = [0x31u8; 16];
let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131")); let config = Arc::new(test_config_with_secret_hex(
"31313131313131313131313131313131",
));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap();
@ -2845,7 +2878,10 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); assert_eq!(
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected)
);
} }
{ {
@ -2924,7 +2960,10 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); assert_eq!(
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected)
);
} }
{ {
@ -3148,7 +3187,9 @@ async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grac
.unwrap_or_else(|poisoned| poisoned.into_inner()); .unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575")); let config = Arc::new(test_config_with_secret_hex(
"75757575757575757575757575757575",
));
let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap();
@ -3296,7 +3337,11 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
} }
let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0));
let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3)); let valid_mtproto = Arc::new(make_valid_mtproto_handshake(
secret_hex,
ProtoTag::Secure,
3,
));
let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
let invalid_tls = Arc::new(invalid_tls); let invalid_tls = Arc::new(invalid_tls);
@ -3368,7 +3413,9 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
match task.await.unwrap() { match task.await.unwrap() {
HandshakeResult::BadClient { .. } => bad_clients += 1, HandshakeResult::BadClient { .. } => bad_clients += 1,
HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"),
HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"), HandshakeResult::Error(err) => {
panic!("unexpected error in invalid TLS saturation burst test: {err}")
}
} }
} }
@ -3385,8 +3432,7 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
); );
assert_eq!( assert_eq!(
bad_clients, bad_clients, 48,
48,
"all invalid TLS probes in mixed saturation burst must be rejected" "all invalid TLS probes in mixed saturation burst must be rejected"
); );
} }

View File

@ -57,6 +57,25 @@ fn spread_u128(values: &[u128]) -> u128 {
max_v - min_v max_v - min_v
} }
fn interval_gap_usize(a: &BTreeSet<usize>, b: &BTreeSet<usize>) -> usize {
if a.is_empty() || b.is_empty() {
return 0;
}
let a_min = *a.iter().next().unwrap();
let a_max = *a.iter().next_back().unwrap();
let b_min = *b.iter().next().unwrap();
let b_max = *b.iter().next_back().unwrap();
if a_max < b_min {
b_min - a_max
} else if b_max < a_min {
a_min - b_max
} else {
0
}
}
async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec<u128> { async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec<u128> {
let mut out = Vec::with_capacity(n); let mut out = Vec::with_capacity(n);
for _ in 0..n { for _ in 0..n {
@ -266,23 +285,40 @@ async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baselin
let baseline_overlap = baseline_a.intersection(&baseline_b).count(); let baseline_overlap = baseline_a.intersection(&baseline_b).count();
let hardened_overlap = hardened_a.intersection(&hardened_b).count(); let hardened_overlap = hardened_a.intersection(&hardened_b).count();
let baseline_gap = interval_gap_usize(&baseline_a, &baseline_b);
let hardened_gap = interval_gap_usize(&hardened_a, &hardened_b);
println!( println!(
"ab_harness_length baseline_overlap={} hardened_overlap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}", "ab_harness_length baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}",
baseline_overlap, baseline_overlap,
hardened_overlap, hardened_overlap,
baseline_gap,
hardened_gap,
baseline_a.len(), baseline_a.len(),
baseline_b.len(), baseline_b.len(),
hardened_a.len(), hardened_a.len(),
hardened_b.len() hardened_b.len()
); );
assert_eq!(baseline_overlap, 0, "baseline above-cap classes should be disjoint"); assert_eq!(
baseline_overlap, 0,
"baseline above-cap classes should be disjoint"
);
assert!( assert!(
hardened_overlap > baseline_overlap, hardened_a.len() > baseline_a.len() && hardened_b.len() > baseline_b.len(),
"above-cap blur should increase cross-class overlap: baseline={} hardened={}", "above-cap blur should widen per-class emitted lengths: baseline_a={} baseline_b={} hardened_a={} hardened_b={}",
baseline_a.len(),
baseline_b.len(),
hardened_a.len(),
hardened_b.len()
);
assert!(
hardened_overlap > baseline_overlap || hardened_gap < baseline_gap,
"above-cap blur should reduce class separability via direct overlap or tighter interval gap: baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={}",
baseline_overlap, baseline_overlap,
hardened_overlap hardened_overlap,
baseline_gap,
hardened_gap
); );
} }
@ -314,7 +350,10 @@ fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() {
let a = vec![10u128, 11, 12, 13, 14]; let a = vec![10u128, 11, 12, 13, 14];
let b = vec![10u128, 11, 12, 13, 14]; let b = vec![10u128, 11, 12, 13, 14];
let acc = best_threshold_accuracy_u128(&a, &b); let acc = best_threshold_accuracy_u128(&a, &b);
assert!(acc <= 0.6, "identical sets should not be strongly separable"); assert!(
acc <= 0.6,
"identical sets should not be strongly separable"
);
} }
#[test] #[test]
@ -336,7 +375,10 @@ async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separ
let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await;
let acc = best_threshold_accuracy_u128(&fail, &slow); let acc = best_threshold_accuracy_u128(&fail, &slow);
assert!(acc >= 0.80, "baseline timing classes should be separable enough"); assert!(
acc >= 0.80,
"baseline timing classes should be separable enough"
);
} }
#[tokio::test] #[tokio::test]
@ -408,7 +450,10 @@ async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_
let fail_mean = mean_ms(&fail); let fail_mean = mean_ms(&fail);
let success_mean = mean_ms(&success); let success_mean = mean_ms(&success);
let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20; let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20;
assert!(delta_bucket <= 3, "mean bucket delta too large: {delta_bucket}"); assert!(
delta_bucket <= 3,
"mean bucket delta too large: {delta_bucket}"
);
} }
#[tokio::test] #[tokio::test]
@ -418,7 +463,10 @@ async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_i
let p95_success = percentile_ms(success, 95, 100); let p95_success = percentile_ms(success, 95, 100);
let p95_slow = percentile_ms(slow, 95, 100); let p95_slow = percentile_ms(slow, 95, 100);
let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20; let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20;
assert!(delta_bucket <= 4, "p95 bucket delta too large: {delta_bucket}"); assert!(
delta_bucket <= 4,
"p95 bucket delta too large: {delta_bucket}"
);
} }
#[tokio::test] #[tokio::test]
@ -434,7 +482,10 @@ async fn timing_classifier_normalized_spread_is_not_worse_than_baseline_for_conn
} }
#[tokio::test] #[tokio::test]
async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() { async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization()
{
const SAMPLE_COUNT: usize = 6;
let pairs = [ let pairs = [
(PathClass::ConnectFail, PathClass::ConnectSuccess), (PathClass::ConnectFail, PathClass::ConnectSuccess),
(PathClass::ConnectFail, PathClass::SlowBackend), (PathClass::ConnectFail, PathClass::SlowBackend),
@ -445,12 +496,14 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
let mut baseline_sum = 0.0f64; let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64; let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize; let mut pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
for (a, b) in pairs { for (a, b) in pairs {
let baseline_a = collect_timing_samples(a, false, 6).await; let baseline_a = collect_timing_samples(a, false, SAMPLE_COUNT).await;
let baseline_b = collect_timing_samples(b, false, 6).await; let baseline_b = collect_timing_samples(b, false, SAMPLE_COUNT).await;
let hardened_a = collect_timing_samples(a, true, 6).await; let hardened_a = collect_timing_samples(a, true, SAMPLE_COUNT).await;
let hardened_b = collect_timing_samples(b, true, 6).await; let hardened_b = collect_timing_samples(b, true, SAMPLE_COUNT).await;
let baseline_acc = best_threshold_accuracy_u128( let baseline_acc = best_threshold_accuracy_u128(
&bucketize_ms(&baseline_a, 20), &bucketize_ms(&baseline_a, 20),
@ -466,11 +519,15 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
// Guard hard only on informative baseline pairs. // Guard hard only on informative baseline pairs.
if baseline_acc >= 0.75 { if baseline_acc >= 0.75 {
assert!( assert!(
hardened_acc <= baseline_acc + 0.05, hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3}" "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
); );
} }
println!(
"timing_classifier_pair baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated_pair_regression={tolerated_pair_regression:.3}"
);
if hardened_acc + 0.05 <= baseline_acc { if hardened_acc + 0.05 <= baseline_acc {
meaningful_improvement_seen = true; meaningful_improvement_seen = true;
} }
@ -484,7 +541,7 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
let hardened_avg = hardened_sum / pair_count as f64; let hardened_avg = hardened_sum / pair_count as f64;
assert!( assert!(
hardened_avg <= baseline_avg + 0.08, hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
); );
@ -504,7 +561,10 @@ async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded()
_ => PathClass::SlowBackend, _ => PathClass::SlowBackend,
}; };
let sample = measure_masking_duration_ms(class, true).await; let sample = measure_masking_duration_ms(class, true).await;
assert!((100..=1600).contains(&sample), "stress sample out of bounds: {sample}"); assert!(
(100..=1600).contains(&sample),
"stress sample out of bounds: {sample}"
);
})); }));
} }

View File

@ -1,13 +1,13 @@
use super::*; use super::*;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Instant, Duration};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::proxy::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stream::BufferPool; use crate::stream::BufferPool;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Probing Indistinguishability (OWASP ASVS 5.1.7) // Probing Indistinguishability (OWASP ASVS 5.1.7)
@ -28,7 +28,10 @@ async fn masking_probes_indistinguishable_timing() {
let probes = vec![ let probes = vec![
(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"),
(b"SSH-2.0-probe".to_vec(), "SSH"), (b"SSH-2.0-probe".to_vec(), "SSH"),
(vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"), (
vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00],
"TLS-scanner",
),
(vec![0x42; 5], "port-scanner"), (vec![0x42; 5], "port-scanner"),
]; ];
@ -45,13 +48,17 @@ async fn masking_probes_indistinguishable_timing() {
local_addr, local_addr,
&config, &config,
&beobachten, &beobachten,
).await; )
.await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests)
// to mask whether the backend was reachable or refused. // to mask whether the backend was reachable or refused.
assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}"); assert!(
elapsed >= Duration::from_millis(30),
"Probe {type_name} finished too fast: {elapsed:?}"
);
} }
} }
@ -87,14 +94,18 @@ async fn masking_budget_stress_under_load() {
local_addr, local_addr,
&config, &config,
&beobachten, &beobachten,
).await; )
.await;
start.elapsed() start.elapsed()
})); }));
} }
for task in tasks { for task in tasks {
let elapsed = task.await.unwrap(); let elapsed = task.await.unwrap();
assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}"); assert!(
elapsed >= Duration::from_millis(30),
"Stress probe finished too fast: {elapsed:?}"
);
} }
} }
@ -133,7 +144,9 @@ async fn masking_slowloris_client_idle_timeout_rejected() {
assert_eq!(observed, initial); assert_eq!(observed, initial);
let mut drip = [0u8; 1]; let mut drip = [0u8; 1];
let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await; let drip_read =
tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip))
.await;
assert!( assert!(
drip_read.is_err() || drip_read.unwrap().is_err(), drip_read.is_err() || drip_read.unwrap().is_err(),
"backend must not receive post-timeout slowloris drip bytes" "backend must not receive post-timeout slowloris drip bytes"
@ -190,11 +203,24 @@ async fn masking_fallback_down_mimics_timeout() {
let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let start = Instant::now(); let start = Instant::now();
handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await; handle_bad_client(
server_reader,
server_writer,
b"GET / HTTP/1.1\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately
assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed); assert!(
elapsed >= Duration::from_millis(40),
"Must respect connect budget even on failure: {:?}",
elapsed
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -205,7 +231,13 @@ async fn masking_fallback_down_mimics_timeout() {
async fn masking_ssrf_resolve_internal_ranges_blocked() { async fn masking_ssrf_resolve_internal_ranges_blocked() {
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"]; let blocked_ips = [
"127.0.0.1",
"169.254.169.254",
"10.0.0.1",
"192.168.1.1",
"0.0.0.0",
];
for ip in blocked_ips { for ip in blocked_ips {
assert!( assert!(
@ -270,7 +302,10 @@ async fn masking_zero_length_initial_data_does_not_hang_or_panic() {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert_eq!(n, 0, "backend must observe clean EOF for empty initial payload"); assert_eq!(
n, 0,
"backend must observe clean EOF for empty initial payload"
);
}); });
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
@ -312,7 +347,10 @@ async fn masking_oversized_initial_payload_is_forwarded_verbatim() {
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; payload.len()]; let mut observed = vec![0u8; payload.len()];
stream.read_exact(&mut observed).await.unwrap(); stream.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, payload, "large initial payload must stay byte-for-byte"); assert_eq!(
observed, payload,
"large initial payload must stay byte-for-byte"
);
} }
}); });
@ -491,7 +529,10 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() {
}); });
let mut observed = vec![0u8; expected_reply.len()]; let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply); assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(2), handle) timeout(Duration::from_secs(2), handle)
@ -646,7 +687,10 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() {
}); });
let mut observed = vec![0u8; expected_reply.len()]; let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply); assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(3), handle) timeout(Duration::from_secs(3), handle)

View File

@ -0,0 +1,107 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
async fn capture_forwarded_len_with_mode(
body_sent: usize,
close_client_after_write: bool,
aggressive_mode: bool,
above_cap_blur: bool,
above_cap_blur_max_bytes: usize,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_shape_hardening = true;
config.censorship.mask_shape_hardening_aggressive_mode = aggressive_mode;
config.censorship.mask_shape_bucket_floor_bytes = 512;
config.censorship.mask_shape_bucket_cap_bytes = 4096;
config.censorship.mask_shape_above_cap_blur = above_cap_blur;
config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_reader, mut client_writer) = duplex(64 * 1024);
let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
let peer: SocketAddr = "198.51.100.248:57248".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&7000u16.to_be_bytes());
probe[5..].fill(0x31);
let fallback = tokio::spawn(async move {
handle_bad_client(
server_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
if close_client_after_write {
client_writer.shutdown().await.unwrap();
} else {
client_writer.write_all(b"keepalive").await.unwrap();
tokio::time::sleep(Duration::from_millis(170)).await;
drop(client_writer);
}
let _ = tokio::time::timeout(Duration::from_secs(4), fallback)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
let body_sent = 17usize;
let floor = 512usize;
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
assert!(
aggressive >= floor,
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"
);
}
#[tokio::test]
async fn aggressive_mode_enforces_positive_above_cap_blur() {
let body_sent = 5000usize;
let base = 5 + body_sent;
for _ in 0..48 {
let observed = capture_forwarded_len_with_mode(body_sent, true, true, true, 1).await;
assert!(
observed > base,
"aggressive mode must not emit exact base length when blur is enabled (observed={observed}, base={base})"
);
}
}

View File

@ -1,14 +1,14 @@
use super::*; use super::*;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::io::{AsyncBufReadExt, BufReader, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[cfg(unix)] #[cfg(unix)]
use tokio::net::UnixListener; use tokio::net::UnixListener;
use tokio::time::{Instant, sleep, timeout, Duration}; use tokio::time::{Duration, Instant, sleep, timeout};
#[tokio::test] #[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -56,7 +56,10 @@ async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -108,7 +111,10 @@ async fn tls_scanner_probe_keeps_http_like_fallback_surface() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
@ -147,8 +153,8 @@ fn build_mask_proxy_header_v2_matches_builder_output() {
let expected = ProxyProtocolV2Builder::new() let expected = ProxyProtocolV2Builder::new()
.with_addrs(peer, local_addr) .with_addrs(peer, local_addr)
.build(); .build();
let actual = build_mask_proxy_header(2, peer, local_addr) let actual =
.expect("v2 mode must produce a header"); build_mask_proxy_header(2, peer, local_addr).expect("v2 mode must produce a header");
assert_eq!(actual, expected, "v2 header bytes must be deterministic"); assert_eq!(actual, expected, "v2 header bytes must be deterministic");
} }
@ -159,8 +165,8 @@ fn build_mask_proxy_header_v1_mixed_ip_family_uses_generic_unknown_form() {
let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let expected = ProxyProtocolV1Builder::new().build(); let expected = ProxyProtocolV1Builder::new().build();
let actual = build_mask_proxy_header(1, peer, local_addr) let actual =
.expect("v1 mode must produce a header"); build_mask_proxy_header(1, peer, local_addr).expect("v1 mode must produce a header");
assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form"); assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form");
} }
@ -197,7 +203,10 @@ async fn beobachten_records_scanner_class_when_mask_is_disabled() {
client_reader_side.write_all(b"noise").await.unwrap(); client_reader_side.write_all(b"noise").await.unwrap();
drop(client_reader_side); drop(client_reader_side);
let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); let beobachten = timeout(Duration::from_secs(3), task)
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[SSH]")); assert!(snapshot.contains("[SSH]"));
assert!(snapshot.contains("203.0.113.99-1")); assert!(snapshot.contains("203.0.113.99-1"));
@ -241,7 +250,10 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
client_reader_side.write_all(b"noise").await.unwrap(); client_reader_side.write_all(b"noise").await.unwrap();
drop(client_reader_side); drop(client_reader_side);
timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); timeout(Duration::from_secs(3), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
@ -393,9 +405,9 @@ async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget
.await; .await;
}); });
timeout(Duration::from_millis(35), task) timeout(Duration::from_millis(35), task).await.expect_err(
.await "proxy-header write error path should remain inside coarse masking budget window",
.expect_err("proxy-header write error path should remain inside coarse masking budget window"); );
assert!( assert!(
started.elapsed() >= Duration::from_millis(35), started.elapsed() >= Duration::from_millis(35),
"proxy-header write error path should avoid immediate-return timing signature" "proxy-header write error path should avoid immediate-return timing signature"
@ -450,9 +462,9 @@ async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budge
.await; .await;
}); });
timeout(Duration::from_millis(35), task) timeout(Duration::from_millis(35), task).await.expect_err(
.await "unix proxy-header write error path should remain inside coarse masking budget window",
.expect_err("unix proxy-header write error path should remain inside coarse masking budget window"); );
assert!( assert!(
started.elapsed() >= Duration::from_millis(35), started.elapsed() >= Duration::from_millis(35),
"unix proxy-header write error path should avoid immediate-return timing signature" "unix proxy-header write error path should avoid immediate-return timing signature"
@ -486,8 +498,14 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() {
let mut header_line = Vec::new(); let mut header_line = Vec::new();
reader.read_until(b'\n', &mut header_line).await.unwrap(); reader.read_until(b'\n', &mut header_line).await.unwrap();
let header_text = String::from_utf8(header_line).unwrap(); let header_text = String::from_utf8(header_line).unwrap();
assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix"); assert!(
assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF"); header_text.starts_with("PROXY "),
"must start with PROXY prefix"
);
assert!(
header_text.ends_with("\r\n"),
"v1 header must end with CRLF"
);
let mut received_probe = vec![0u8; probe.len()]; let mut received_probe = vec![0u8; probe.len()];
reader.read_exact(&mut received_probe).await.unwrap(); reader.read_exact(&mut received_probe).await.unwrap();
@ -523,7 +541,10 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -552,7 +573,10 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
let mut sig = [0u8; 12]; let mut sig = [0u8; 12];
stream.read_exact(&mut sig).await.unwrap(); stream.read_exact(&mut sig).await.unwrap();
assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec"); assert_eq!(
&sig, b"\r\n\r\n\0\r\nQUIT\n",
"v2 signature must match spec"
);
let mut fixed = [0u8; 4]; let mut fixed = [0u8; 4];
stream.read_exact(&mut fixed).await.unwrap(); stream.read_exact(&mut fixed).await.unwrap();
@ -593,7 +617,10 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -893,10 +920,16 @@ async fn mask_disabled_consumes_client_data_without_response() {
.await; .await;
}); });
client_reader_side.write_all(b"untrusted payload").await.unwrap(); client_reader_side
.write_all(b"untrusted payload")
.await
.unwrap();
drop(client_reader_side); drop(client_reader_side);
timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); timeout(Duration::from_secs(3), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
@ -962,7 +995,10 @@ async fn proxy_protocol_v1_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1026,7 +1062,10 @@ async fn proxy_protocol_v2_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1086,7 +1125,10 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1094,7 +1136,11 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() {
#[cfg(unix)] #[cfg(unix)]
#[tokio::test] #[tokio::test]
async fn unix_socket_mask_path_forwards_probe_and_response() { async fn unix_socket_mask_path_forwards_probe_and_response() {
let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::<u64>()); let sock_path = format!(
"/tmp/telemt-mask-test-{}-{}.sock",
std::process::id(),
rand::random::<u64>()
);
let _ = std::fs::remove_file(&sock_path); let _ = std::fs::remove_file(&sock_path);
let listener = UnixListener::bind(&sock_path).unwrap(); let listener = UnixListener::bind(&sock_path).unwrap();
@ -1138,7 +1184,10 @@ async fn unix_socket_mask_path_forwards_probe_and_response() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -1171,7 +1220,10 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
.await; .await;
}); });
timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); timeout(Duration::from_secs(1), task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -1323,17 +1375,24 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
0, 0,
false, false,
0, 0,
false,
) )
.await; .await;
}); });
// Allow relay tasks to start, then emulate mask backend response. // Allow relay tasks to start, then emulate mask backend response.
sleep(Duration::from_millis(20)).await; sleep(Duration::from_millis(20)).await;
backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); backend_feed_writer
.write_all(b"HTTP/1.1 200 OK\r\n\r\n")
.await
.unwrap();
backend_feed_writer.shutdown().await.unwrap(); backend_feed_writer.shutdown().await.unwrap();
let mut observed = vec![0u8; 19]; let mut observed = vec![0u8; 19];
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed)) timeout(
Duration::from_secs(1),
client_visible_reader.read_exact(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
@ -1394,14 +1453,23 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
client_write.shutdown().await.unwrap(); client_write.shutdown().await.unwrap();
let mut observed_resp = vec![0u8; response.len()]; let mut observed_resp = vec![0u8; response.len()];
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp)) timeout(
Duration::from_secs(1),
client_visible_reader.read_exact(&mut observed_resp),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert_eq!(observed_resp, response); assert_eq!(observed_resp, response);
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap(); timeout(Duration::from_secs(1), fallback_task)
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap(); .await
.unwrap()
.unwrap();
timeout(Duration::from_secs(1), backend_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -1437,6 +1505,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
0, 0,
false, false,
0, 0,
false,
), ),
) )
.await; .await;
@ -1574,9 +1643,11 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() {
(mean, min, p95, max) (mean, min, p95, max)
} }
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples); let (disabled_mean, disabled_min, disabled_p95, disabled_max) =
summarize(&mut disabled_samples);
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples); let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples); let (reachable_mean, reachable_min, reachable_p95, reachable_max) =
summarize(&mut reachable_samples);
println!( println!(
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
@ -1698,7 +1769,10 @@ async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() {
let elapsed = started.elapsed(); let elapsed = started.elapsed();
let mut observed = vec![0u8; response.len()]; let mut observed = vec![0u8; response.len()];
client_visible_reader.read_exact(&mut observed).await.unwrap(); client_visible_reader
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, response); assert_eq!(observed, response);
assert!( assert!(
elapsed < Duration::from_millis(190), elapsed < Duration::from_millis(190),
@ -1763,6 +1837,9 @@ async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() {
let _ = client_writer_side.write_all(b"X").await; let _ = client_writer_side.write_all(b"X").await;
drop(client_writer_side); drop(client_writer_side);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); timeout(Duration::from_secs(1), relay_task)
.await
.unwrap()
.unwrap();
accept_task.await.unwrap(); accept_task.await.unwrap();
} }

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -0,0 +1,182 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
async fn capture_forwarded_len_with_optional_eof(
body_sent: usize,
shape_hardening: bool,
above_cap_blur: bool,
above_cap_blur_max_bytes: usize,
close_client_after_write: bool,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_shape_hardening = shape_hardening;
config.censorship.mask_shape_bucket_floor_bytes = 512;
config.censorship.mask_shape_bucket_cap_bytes = 4096;
config.censorship.mask_shape_above_cap_blur = above_cap_blur;
config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_reader, mut client_writer) = duplex(64 * 1024);
let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
let peer: SocketAddr = "198.51.100.241:57241".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&7000u16.to_be_bytes());
probe[5..].fill(0x73);
let fallback = tokio::spawn(async move {
handle_bad_client(
server_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
if close_client_after_write {
client_writer.shutdown().await.unwrap();
} else {
client_writer.write_all(b"keepalive").await.unwrap();
tokio::time::sleep(Duration::from_millis(170)).await;
drop(client_writer);
}
let _ = tokio::time::timeout(Duration::from_secs(4), fallback)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
#[ignore = "red-team detector: shaping on non-EOF timeout path is disabled by design to prevent post-timeout tail leaks"]
async fn security_shape_padding_applies_without_client_eof_when_backend_silent() {
let body_sent = 17usize;
let hardened_floor = 512usize;
let with_eof = capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, true).await;
let without_eof =
capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, false).await;
assert!(
with_eof >= hardened_floor,
"EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})"
);
assert!(
without_eof >= hardened_floor,
"non-EOF path should also be shaped when backend is silent (without_eof={without_eof}, floor={hardened_floor})"
);
}
#[tokio::test]
#[ignore = "red-team detector: blur currently allows zero-extra sample by design within [0..=max] bound"]
async fn security_above_cap_blur_never_emits_exact_base_length() {
let body_sent = 5000usize;
let base = 5 + body_sent;
let max_blur = 1usize;
for _ in 0..64 {
let observed =
capture_forwarded_len_with_optional_eof(body_sent, true, true, max_blur, true).await;
assert!(
observed > base,
"above-cap blur must add at least one byte when enabled (observed={observed}, base={base})"
);
}
}
#[tokio::test]
#[ignore = "red-team detector: shape padding currently depends on EOF, enabling idle-timeout bypass probes"]
async fn redteam_detector_shape_padding_must_not_depend_on_client_eof() {
let body_sent = 17usize;
let hardened_floor = 512usize;
let with_eof = capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, true).await;
let without_eof =
capture_forwarded_len_with_optional_eof(body_sent, true, false, 0, false).await;
assert!(
with_eof >= hardened_floor,
"sanity check failed: EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})"
);
assert!(
without_eof >= hardened_floor,
"strict anti-probing model expects shaping even without EOF; observed without_eof={without_eof}, floor={hardened_floor}"
);
}
#[tokio::test]
#[ignore = "red-team detector: zero-extra above-cap blur samples leak exact class boundary"]
async fn redteam_detector_above_cap_blur_must_never_emit_exact_base_length() {
let body_sent = 5000usize;
let base = 5 + body_sent;
let mut saw_exact_base = false;
let max_blur = 1usize;
for _ in 0..96 {
let observed =
capture_forwarded_len_with_optional_eof(body_sent, true, true, max_blur, true).await;
if observed == base {
saw_exact_base = true;
break;
}
}
assert!(
!saw_exact_base,
"strict anti-classifier model expects >0 blur always; observed exact base length leaks class"
);
}
#[tokio::test]
#[ignore = "red-team detector: disjoint above-cap ranges enable near-perfect size-class classification"]
async fn redteam_detector_above_cap_blur_ranges_for_far_classes_should_overlap() {
let mut a_min = usize::MAX;
let mut a_max = 0usize;
let mut b_min = usize::MAX;
let mut b_max = 0usize;
for _ in 0..48 {
let a = capture_forwarded_len_with_optional_eof(5000, true, true, 64, true).await;
let b = capture_forwarded_len_with_optional_eof(7000, true, true, 64, true).await;
a_min = a_min.min(a);
a_max = a_max.max(a);
b_min = b_min.min(b);
b_max = b_max.max(b);
}
let overlap = a_min <= b_max && b_min <= a_max;
assert!(
overlap,
"strict anti-classifier model expects overlapping output bands; class_a=[{a_min},{a_max}] class_b=[{b_min},{b_max}]"
);
}

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -90,9 +90,7 @@ fn nearest_centroid_classifier_accuracy(
samples_b: &[usize], samples_b: &[usize],
samples_c: &[usize], samples_c: &[usize],
) -> f64 { ) -> f64 {
let mean = |xs: &[usize]| -> f64 { let mean = |xs: &[usize]| -> f64 { xs.iter().copied().sum::<usize>() as f64 / xs.len() as f64 };
xs.iter().copied().sum::<usize>() as f64 / xs.len() as f64
};
let ca = mean(samples_a); let ca = mean(samples_a);
let cb = mean(samples_b); let cb = mean(samples_b);
@ -104,11 +102,7 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_a { for &x in samples_a {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [ let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()];
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[0] <= d[1] && d[0] <= d[2] { if d[0] <= d[1] && d[0] <= d[2] {
correct += 1; correct += 1;
} }
@ -117,11 +111,7 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_b { for &x in samples_b {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [ let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()];
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[1] <= d[0] && d[1] <= d[2] { if d[1] <= d[0] && d[1] <= d[2] {
correct += 1; correct += 1;
} }
@ -130,11 +120,7 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_c { for &x in samples_c {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [ let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()];
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[2] <= d[0] && d[2] <= d[1] { if d[2] <= d[0] && d[2] <= d[1] {
correct += 1; correct += 1;
} }
@ -166,7 +152,10 @@ async fn masking_shape_classifier_resistance_blur_reduces_threshold_attack_accur
let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b); let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b);
// Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack. // Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack.
assert!(baseline_acc >= 0.99, "baseline separability unexpectedly low: {baseline_acc:.3}"); assert!(
baseline_acc >= 0.99,
"baseline separability unexpectedly low: {baseline_acc:.3}"
);
// Blur must materially reduce the best one-dimensional length classifier. // Blur must materially reduce the best one-dimensional length classifier.
assert!( assert!(
hardened_acc <= 0.90, hardened_acc <= 0.90,
@ -247,7 +236,11 @@ async fn masking_shape_classifier_resistance_edge_max_extra_one_has_two_point_su
seen.insert(observed); seen.insert(observed);
} }
assert_eq!(seen.len(), 2, "both support points should appear under repeated sampling"); assert_eq!(
seen.len(),
2,
"both support points should appear under repeated sampling"
);
} }
#[tokio::test] #[tokio::test]
@ -262,13 +255,25 @@ async fn masking_shape_classifier_resistance_negative_blur_without_shape_hardeni
bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await); bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await);
} }
assert_eq!(as_observed.len(), 1, "without shape hardening class A must stay deterministic"); assert_eq!(
assert_eq!(bs_observed.len(), 1, "without shape hardening class B must stay deterministic"); as_observed.len(),
assert_ne!(as_observed, bs_observed, "distinct classes should remain separable without shaping"); 1,
"without shape hardening class A must stay deterministic"
);
assert_eq!(
bs_observed.len(),
1,
"without shape hardening class B must stay deterministic"
);
assert_ne!(
as_observed, bs_observed,
"distinct classes should remain separable without shaping"
);
} }
#[tokio::test] #[tokio::test]
async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() { async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur()
{
const SAMPLES: usize = 80; const SAMPLES: usize = 80;
const MAX_EXTRA: usize = 96; const MAX_EXTRA: usize = 96;
const C1: usize = 5000; const C1: usize = 5000;
@ -295,13 +300,23 @@ async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_at
let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3); let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3);
let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3); let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3);
assert!(base_acc >= 0.99, "baseline centroid separability should be near-perfect"); assert!(
assert!(hard_acc <= 0.88, "blur should materially degrade 3-class centroid attack"); base_acc >= 0.99,
assert!(hard_acc <= base_acc - 0.1, "accuracy drop should be meaningful"); "baseline centroid separability should be near-perfect"
);
assert!(
hard_acc <= 0.88,
"blur should materially degrade 3-class centroid attack"
);
assert!(
hard_acc <= base_acc - 0.1,
"accuracy drop should be meaningful"
);
} }
#[tokio::test] #[tokio::test]
async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() { async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign()
{
let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE; let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE;
for _ in 0..96 { for _ in 0..96 {
s ^= s << 7; s ^= s << 7;

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, empty, sink};
use tokio::time::{sleep, timeout, Duration}; use tokio::time::{Duration, sleep, timeout};
fn oracle_len( fn oracle_len(
total_sent: usize, total_sent: usize,
@ -42,6 +42,7 @@ async fn run_relay_case(
cap, cap,
above_cap_blur, above_cap_blur,
above_cap_blur_max_bytes, above_cap_blur_max_bytes,
false,
) )
.await; .await;
}); });
@ -54,14 +55,20 @@ async fn run_relay_case(
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
} }
timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
if !close_client { if !close_client {
drop(client_writer); drop(client_writer);
} }
let mut observed = Vec::new(); let mut observed = Vec::new();
timeout(Duration::from_secs(2), mask_observer.read_to_end(&mut observed)) timeout(
Duration::from_secs(2),
mask_observer.read_to_end(&mut observed),
)
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
@ -97,12 +104,29 @@ async fn masking_shape_guard_positive_clean_eof_path_shapes_and_preserves_prefix
let extra = vec![0x55; 300]; let extra = vec![0x55; 300];
let total = initial.len() + extra.len(); let total = initial.len() + extra.len();
let observed = run_relay_case(initial.clone(), extra.clone(), true, true, 512, 4096, false, 0).await; let observed = run_relay_case(
initial.clone(),
extra.clone(),
true,
true,
512,
4096,
false,
0,
)
.await;
let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096); let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096);
assert_eq!(observed.len(), expected_len, "clean EOF path must be bucket-shaped"); assert_eq!(
observed.len(),
expected_len,
"clean EOF path must be bucket-shaped"
);
assert_eq!(&observed[..initial.len()], initial.as_slice()); assert_eq!(&observed[..initial.len()], initial.as_slice());
assert_eq!(&observed[initial.len()..(initial.len() + extra.len())], extra.as_slice()); assert_eq!(
&observed[initial.len()..(initial.len() + extra.len())],
extra.as_slice()
);
} }
#[tokio::test] #[tokio::test]
@ -112,7 +136,11 @@ async fn masking_shape_guard_edge_empty_initial_remains_transparent_under_clean_
let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await; let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await;
assert_eq!(observed.len(), extra.len(), "empty initial_data must never trigger shaping"); assert_eq!(
observed.len(),
extra.len(),
"empty initial_data must never trigger shaping"
);
assert_eq!(observed, extra); assert_eq!(observed, extra);
} }
@ -212,13 +240,19 @@ async fn masking_shape_guard_stress_parallel_mixed_sessions_keep_oracle_and_no_h
assert_eq!(&observed[..initial_len], initial.as_slice()); assert_eq!(&observed[..initial_len], initial.as_slice());
} }
if extra_len > 0 { if extra_len > 0 {
assert_eq!(&observed[initial_len..(initial_len + extra_len)], extra.as_slice()); assert_eq!(
&observed[initial_len..(initial_len + extra_len)],
extra.as_slice()
);
} }
})); }));
} }
for task in tasks { for task in tasks {
timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); timeout(Duration::from_secs(3), task)
.await
.unwrap()
.unwrap();
} }
} }
@ -238,7 +272,10 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l
let mut one = [0u8; 1]; let mut one = [0u8; 1];
let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await;
assert!(r.is_err() || r.unwrap().is_err(), "no post-timeout drip/tail may reach backend"); assert!(
r.is_err() || r.unwrap().is_err(),
"no post-timeout drip/tail may reach backend"
);
} }
}); });
@ -274,8 +311,14 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l
sleep(Duration::from_millis(160)).await; sleep(Duration::from_millis(160)).await;
let _ = client_writer.write_all(b"X").await; let _ = client_writer.write_all(b"X").await;
timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); timeout(Duration::from_secs(2), relay)
timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); .await
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -352,7 +395,10 @@ async fn masking_shape_guard_above_cap_blur_parallel_stress_keeps_bounds() {
} }
for task in tasks { for task in tasks {
timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); timeout(Duration::from_secs(3), task)
.await
.unwrap()
.unwrap();
} }
} }

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{timeout, Duration}; use tokio::time::{Duration, timeout};
#[tokio::test] #[tokio::test]
async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() { async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() {
@ -15,7 +15,10 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof()
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new(); let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap(); stream.read_to_end(&mut got).await.unwrap();
assert_eq!(got, expected, "empty initial_data path must not inject shape padding"); assert_eq!(
got, expected,
"empty initial_data path must not inject shape padding"
);
} }
}); });
@ -51,8 +54,14 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof()
client_writer.write_all(&client_payload).await.unwrap(); client_writer.write_all(&client_payload).await.unwrap();
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); timeout(Duration::from_secs(2), relay_task)
timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); .await
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -105,7 +114,10 @@ async fn shape_guard_timeout_exit_does_not_append_padding_after_initial_probe()
) )
.await; .await;
timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -126,7 +138,11 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin
let expected_prefix_len = initial.len() + extra.len(); let expected_prefix_len = initial.len() + extra.len();
assert_eq!(&got[..initial.len()], initial.as_slice()); assert_eq!(&got[..initial.len()], initial.as_slice());
assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice()); assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice());
assert_eq!(got.len(), 512, "clean EOF path should still shape to floor bucket"); assert_eq!(
got.len(),
512,
"clean EOF path should still shape to floor bucket"
);
} }
}); });
@ -162,6 +178,12 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin
client_writer.write_all(&extra).await.unwrap(); client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); timeout(Duration::from_secs(2), relay_task)
timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); .await
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncReadExt, AsyncWrite, duplex, empty, sink};
struct CountingWriter { struct CountingWriter {
written: usize, written: usize,
@ -46,21 +46,24 @@ fn shape_bucket_clamps_to_cap_when_next_power_of_two_exceeds_cap() {
fn shape_bucket_never_drops_below_total_for_valid_ranges() { fn shape_bucket_never_drops_below_total_for_valid_ranges() {
for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] { for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] {
let bucket = next_mask_shape_bucket(total, 1000, 1500); let bucket = next_mask_shape_bucket(total, 1000, 1500);
assert!(bucket >= total || total >= 1500, "bucket={bucket} total={total}"); assert!(
bucket >= total || total >= 1500,
"bucket={bucket} total={total}"
);
} }
} }
#[tokio::test] #[tokio::test]
async fn maybe_write_shape_padding_writes_exact_delta() { async fn maybe_write_shape_padding_writes_exact_delta() {
let mut writer = CountingWriter::new(); let mut writer = CountingWriter::new();
maybe_write_shape_padding(&mut writer, 1200, true, 1000, 1500, false, 0).await; maybe_write_shape_padding(&mut writer, 1200, true, 1000, 1500, false, 0, false).await;
assert_eq!(writer.written, 300); assert_eq!(writer.written, 300);
} }
#[tokio::test] #[tokio::test]
async fn maybe_write_shape_padding_skips_when_disabled() { async fn maybe_write_shape_padding_skips_when_disabled() {
let mut writer = CountingWriter::new(); let mut writer = CountingWriter::new();
maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0).await; maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0, false).await;
assert_eq!(writer.written, 0); assert_eq!(writer.written, 0);
} }
@ -84,6 +87,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
1500, 1500,
false, false,
0, 0,
false,
) )
.await; .await;
}); });

View File

@ -115,6 +115,12 @@ async fn timing_normalization_does_not_sleep_if_path_already_exceeds_ceiling() {
let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await; let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await;
assert!(slow >= 280, "slow backend path should remain slow (got {slow}ms)"); assert!(
assert!(slow <= 520, "slow backend path should remain bounded in tests (got {slow}ms)"); slow >= 280,
"slow backend path should remain slow (got {slow}ms)"
);
assert!(
slow <= 520,
"slow backend path should remain bounded in tests (got {slow}ms)"
);
} }

View File

@ -0,0 +1,112 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-blackhat-held-{}-{idx}",
std::process::id()
)));
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"precondition: bounded lock cache must be saturated"
);
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Close)
.await
.expect("queue prefill should succeed");
let pressure_seq_before = relay_pressure_event_seq();
let pressure_errors = Arc::new(AtomicUsize::new(0));
let mut pressure_workers = Vec::new();
for _ in 0..16 {
let tx = tx.clone();
let pressure_errors = Arc::clone(&pressure_errors);
pressure_workers.push(tokio::spawn(async move {
if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() {
pressure_errors.fetch_add(1, Ordering::Relaxed);
}
}));
}
let stats = Arc::new(Stats::new());
let user = format!("middle-blackhat-quota-race-{}", std::process::id());
let gate = Arc::new(Barrier::new(16));
let mut quota_workers = Vec::new();
for _ in 0..16u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
quota_workers.push(tokio::spawn(async move {
gate.wait().await;
let user_lock = quota_user_lock(&user);
let _quota_guard = user_lock.lock().await;
if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) {
return false;
}
stats.add_user_octets_to(&user, 1);
true
}));
}
let mut ok_count = 0usize;
let mut denied_count = 0usize;
for worker in quota_workers {
let result = timeout(Duration::from_secs(2), worker)
.await
.expect("quota worker must finish")
.expect("quota worker must not panic");
if result {
ok_count += 1;
} else {
denied_count += 1;
}
}
for worker in pressure_workers {
timeout(Duration::from_secs(2), worker)
.await
.expect("pressure worker must finish")
.expect("pressure worker must not panic");
}
assert_eq!(
stats.get_user_total_octets(&user),
1,
"black-hat campaign must not overshoot same-user quota under saturation"
);
assert!(ok_count <= 1, "at most one quota contender may succeed");
assert!(
denied_count >= 15,
"all remaining contenders must be quota-denied"
);
let pressure_seq_after = relay_pressure_event_seq();
assert!(
pressure_seq_after > pressure_seq_before,
"queue pressure leg must trigger pressure accounting"
);
assert!(
pressure_errors.load(Ordering::Relaxed) >= 1,
"at least one pressure worker should fail from persistent backpressure"
);
drop(retained);
}

View File

@ -0,0 +1,708 @@
use super::*;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
use crate::stats::Stats;
use crate::stream::{BufferPool, PooledBuffer};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::time::{Duration as TokioDuration, timeout};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
let mut payload = pool.get();
payload.resize(data.len(), 0);
payload[..data.len()].copy_from_slice(data);
payload
}
#[tokio::test]
async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("abridged quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8));
assert_eq!(&plaintext[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_header_is_encoded_correctly() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Boundary where abridged switches to extended length encoding.
let payload = vec![0x5Au8; 0x7f * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended abridged payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized extended abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set");
assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]);
assert_eq!(&plaintext[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&[1, 2, 3],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("4-byte aligned"),
"error should explain alignment contract, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&[9, 8, 7, 6, 5],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned secure payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Secure payload must be 4-byte aligned"),
"error should be explicit for fail-closed triage, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_intermediate_quickack_sets_length_msb() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = b"hello-middle-relay";
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
payload,
&rng,
&mut frame_buf,
)
.await
.expect("intermediate quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read intermediate frame");
let plaintext = decryptor.decrypt(&encrypted);
let mut len_bytes = [0u8; 4];
len_bytes.copy_from_slice(&plaintext[..4]);
let len_with_flags = u32::from_le_bytes(len_bytes);
assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set");
assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len());
assert_eq!(&plaintext[4..], payload);
}
#[tokio::test]
async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode.
write_client_payload(
&mut writer,
ProtoTag::Secure,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
// Secure mode adds 1..=3 bytes of randomized tail padding.
let mut encrypted_header = [0u8; 4];
read_side
.read_exact(&mut encrypted_header)
.await
.expect("must read secure header");
let decrypted_header = decryptor.decrypt(&encrypted_header);
let header: [u8; 4] = decrypted_header
.try_into()
.expect("decrypted secure header must be 4 bytes");
let wire_len_raw = u32::from_le_bytes(header);
assert_ne!(
wire_len_raw & 0x8000_0000,
0,
"secure quickack bit must be set"
);
let wire_len = (wire_len_raw & 0x7fff_ffff) as usize;
assert!(wire_len >= payload.len());
let padding_len = wire_len - payload.len();
assert!(
(1..=3).contains(&padding_len),
"secure writer must add bounded random tail padding, got {padding_len}"
);
let mut encrypted_body = vec![0u8; wire_len];
read_side
.read_exact(&mut encrypted_body)
.await
.expect("must read secure body");
let decrypted_body = decryptor.decrypt(&encrypted_body);
assert_eq!(&decrypted_body[..payload.len()], payload.as_slice());
}
#[tokio::test]
#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"]
async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Exactly one 4-byte word above the encodable 24-bit abridged length range.
let payload = vec![0x00u8; (1 << 24) * 4];
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect_err("oversized abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Abridged frame too large"),
"error must clearly indicate oversize fail-close path, got: {msg}"
);
}
#[tokio::test]
async fn write_client_ack_intermediate_is_little_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes());
}
#[tokio::test]
async fn write_client_ack_abridged_is_big_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes());
}
#[tokio::test]
async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() {
let (mut read_side, write_side) = duplex(1024 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xABu8; 0x7e * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("boundary payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7e);
assert_eq!(&plain[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x42u8; 0x80 * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7f);
assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]);
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_intermediate_zero_length_emits_header_only() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0,
&[],
&rng,
&mut frame_buf,
)
.await
.expect("zero-length intermediate payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &[0, 0, 0, 0]);
}
#[tokio::test]
async fn write_client_payload_intermediate_ignores_unrelated_flags() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [7u8; 12];
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0x4000_0000,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 16];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
let len = u32::from_le_bytes(plain[0..4].try_into().unwrap());
assert_eq!(len, payload.len() as u32, "only quickack bit may affect header");
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_secure_without_quickack_keeps_msb_clear() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x1Du8; 64];
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len_raw = u32::from_le_bytes(h);
assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear");
}
#[tokio::test]
async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() {
let (mut read_side, write_side) = duplex(256 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x55u8; 100];
let mut seen = [false; 4];
for _ in 0..96 {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize;
let padding_len = wire_len - payload.len();
assert!((1..=3).contains(&padding_len));
seen[padding_len] = true;
let mut encrypted_body = vec![0u8; wire_len];
read_side.read_exact(&mut encrypted_body).await.unwrap();
let _ = decryptor.decrypt(&encrypted_body);
}
let distinct = (1..=3).filter(|idx| seen[*idx]).count();
assert!(
distinct >= 2,
"padding generator should not collapse to a single outcome under campaign"
);
}
#[tokio::test]
async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() {
let (mut read_side, write_side) = duplex(128 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let p1 = vec![1u8; 8];
let p2 = vec![2u8; 16];
let p3 = vec![3u8; 20];
write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf)
.await
.unwrap();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
&p2,
&rng,
&mut frame_buf,
)
.await
.unwrap();
write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf)
.await
.unwrap();
writer.flush().await.unwrap();
// Frame 1: abridged short.
let mut e1 = vec![0u8; 1 + p1.len()];
read_side.read_exact(&mut e1).await.unwrap();
let d1 = decryptor.decrypt(&e1);
assert_eq!(d1[0], (p1.len() / 4) as u8);
assert_eq!(&d1[1..], p1.as_slice());
// Frame 2: intermediate with quickack.
let mut e2 = vec![0u8; 4 + p2.len()];
read_side.read_exact(&mut e2).await.unwrap();
let d2 = decryptor.decrypt(&e2);
let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap());
assert_ne!(l2 & 0x8000_0000, 0);
assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len());
assert_eq!(&d2[4..], p2.as_slice());
// Frame 3: secure with bounded tail.
let mut e3h = [0u8; 4];
read_side.read_exact(&mut e3h).await.unwrap();
let d3h = decryptor.decrypt(&e3h);
let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
assert!(l3 >= p3.len());
assert!((1..=3).contains(&(l3 - p3.len())));
let mut e3b = vec![0u8; l3];
read_side.read_exact(&mut e3b).await.unwrap();
let d3b = decryptor.decrypt(&e3b);
assert_eq!(&d3b[..p3.len()], p3.as_slice());
}
#[test]
fn should_yield_sender_boundary_matrix_blackhat() {
assert!(!should_yield_c2me_sender(0, false));
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
assert!(should_yield_c2me_sender(
C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024),
true
));
}
#[test]
fn should_yield_sender_light_fuzz_matches_oracle() {
let mut s: u64 = 0xD00D_BAAD_F00D_CAFE;
for _ in 0..5000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let sent = (s as usize) & 0x1fff;
let backlog = (s & 1) != 0;
let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET;
assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
}
}
#[test]
fn quota_would_be_exceeded_exact_remaining_one_byte() {
let stats = Stats::new();
let user = "quota-edge";
let quota = 100u64;
stats.add_user_octets_to(user, 99);
assert!(
!quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1),
"exactly remaining budget should be allowed"
);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"one byte beyond remaining budget must be rejected"
);
}
#[test]
fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
let stats = Stats::new();
let user = "quota-saturating-edge";
let quota = u64::MAX - 3;
stats.add_user_octets_to(user, u64::MAX - 4);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"saturating arithmetic edge must stay fail-closed"
);
}
#[test]
fn quota_exceeded_boundary_is_inclusive() {
let stats = Stats::new();
let user = "quota-inclusive-boundary";
stats.add_user_octets_to(user, 50);
assert!(quota_exceeded_for_user(&stats, user, Some(50)));
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
}
#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
enqueue_c2me_command(&tx, C2MeCommand::Close)
.await
.expect("close should enqueue on fast path");
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.expect("must receive close command")
.expect("close command should be present");
assert!(matches!(recv, C2MeCommand::Close));
}
#[tokio::test]
async fn enqueue_c2me_data_full_then_drain_preserves_order() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[1]),
flags: 10,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: make_pooled_payload(&[2, 2]),
flags: 20,
},
)
.await
});
tokio::time::sleep(TokioDuration::from_millis(10)).await;
let first = rx.recv().await.expect("first item should exist");
match first {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1]);
assert_eq!(flags, 10);
}
C2MeCommand::Close => panic!("unexpected close as first item"),
}
producer.await.unwrap().expect("producer should complete");
let second = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.expect("second item should exist");
match second {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[2, 2]);
assert_eq!(flags, 20);
}
C2MeCommand::Close => panic!("unexpected close as second item"),
}
}

View File

@ -47,7 +47,11 @@ fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() {
); );
} }
assert_eq!(dedup.len(), 2, "bypass path must not mutate dedup cardinality"); assert_eq!(
dedup.len(),
2,
"bypass path must not mutate dedup cardinality"
);
assert_eq!( assert_eq!(
*dedup *dedup
.get(&0xAAAABBBBCCCCDDDD) .get(&0xAAAABBBBCCCCDDDD)
@ -73,7 +77,11 @@ fn edge_all_full_burst_does_not_poison_later_false_path_tracking() {
let now = Instant::now(); let now = Instant::now();
for i in 0..8192u64 { for i in 0..8192u64 {
assert!(should_emit_full_desync(0xABCD_0000_0000_0000 ^ i, true, now)); assert!(should_emit_full_desync(
0xABCD_0000_0000_0000 ^ i,
true,
now
));
} }
let tracked_key = 0xDEAD_BEEF_0000_0001u64; let tracked_key = 0xDEAD_BEEF_0000_0001u64;
@ -175,5 +183,9 @@ fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() {
} }
assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096);
assert_eq!(dedup.len(), before_len, "parallel all_full storm must not mutate cache len"); assert_eq!(
dedup.len(),
before_len,
"parallel all_full storm must not mutate cache len"
);
} }

View File

@ -2,8 +2,8 @@ use super::*;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader}; use crate::stream::{BufferPool, CryptoReader};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::AtomicU64; use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock};
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@ -93,7 +93,9 @@ async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
.await .await
.expect("idle test must complete"); .expect("idle test must complete");
assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); assert!(
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
let err_text = match result { let err_text = match result {
Err(ProxyError::Io(ref e)) => e.to_string(), Err(ProxyError::Io(ref e)) => e.to_string(),
_ => String::new(), _ => String::new(),
@ -143,7 +145,9 @@ async fn idle_policy_downstream_activity_grace_extends_hard_deadline() {
.await .await
.expect("grace test must complete"); .expect("grace test must complete");
assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); assert!(
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
assert!( assert!(
start.elapsed() >= TokioDuration::from_millis(100), start.elapsed() >= TokioDuration::from_millis(100),
"recent downstream activity must extend hard idle deadline" "recent downstream activity must extend hard idle deadline"
@ -171,7 +175,9 @@ async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() {
) )
.await; .await;
assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); assert!(
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
let err_text = match result { let err_text = match result {
Err(ProxyError::Io(ref e)) => e.to_string(), Err(ProxyError::Io(ref e)) => e.to_string(),
_ => String::new(), _ => String::new(),
@ -225,8 +231,13 @@ async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() {
.await .await
.expect("partial frame trickle test must complete"); .expect("partial frame trickle test must complete");
assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); assert!(
assert_eq!(frame_counter, 0, "partial trickle must not count as a valid frame"); matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
assert_eq!(
frame_counter, 0,
"partial trickle must not count as a valid frame"
);
} }
#[tokio::test] #[tokio::test]
@ -291,7 +302,10 @@ async fn protocol_desync_small_frame_updates_reason_counter() {
plaintext.extend_from_slice(&3u32.to_le_bytes()); plaintext.extend_from_slice(&3u32.to_le_bytes());
plaintext.extend_from_slice(&[1u8, 2, 3]); plaintext.extend_from_slice(&[1u8, 2, 3]);
let encrypted = encrypt_for_reader(&plaintext); let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.expect("must write frame"); writer
.write_all(&encrypted)
.await
.expect("must write frame");
let result = read_client_payload( let result = read_client_payload(
&mut crypto_reader, &mut crypto_reader,
@ -657,7 +671,8 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
{
let _guard = acquire_idle_pressure_test_lock(); let _guard = acquire_idle_pressure_test_lock();
clear_relay_idle_pressure_state_for_testing(); clear_relay_idle_pressure_state_for_testing();
@ -680,7 +695,8 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
let conn_id = *conn_id; let conn_id = *conn_id;
let stats = stats.clone(); let stats = stats.clone();
joins.push(tokio::spawn(async move { joins.push(tokio::spawn(async move {
let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); let evicted =
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
(idx, conn_id, seen, evicted) (idx, conn_id, seen, evicted)
})); }));
} }
@ -753,7 +769,8 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
let conn_id = *conn_id; let conn_id = *conn_id;
let stats = stats.clone(); let stats = stats.clone();
joins.push(tokio::spawn(async move { joins.push(tokio::spawn(async move {
let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); let evicted =
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
(idx, conn_id, seen, evicted) (idx, conn_id, seen, evicted)
})); }));
} }

View File

@ -0,0 +1,75 @@
use super::*;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
#[test]
fn intermediate_secure_wire_len_allows_max_31bit_payload() {
let (len_val, total) = compute_intermediate_secure_wire_len(0x7fff_fffe, 1, true)
.expect("31-bit wire length should be accepted");
assert_eq!(len_val, 0xffff_ffff, "quickack must use top bit only");
assert_eq!(total, 0x8000_0003);
}
#[test]
fn intermediate_secure_wire_len_rejects_length_above_31bit_limit() {
let err = compute_intermediate_secure_wire_len(0x7fff_ffff, 1, false)
.expect_err("wire length above 31-bit must fail closed");
assert!(
format!("{err}").contains("frame too large"),
"error should identify oversize frame path"
);
}
#[test]
fn intermediate_secure_wire_len_rejects_addition_overflow() {
let err = compute_intermediate_secure_wire_len(usize::MAX, 1, false)
.expect_err("overflowing addition must fail closed");
assert!(
format!("{err}").contains("overflow"),
"error should clearly report overflow"
);
}
#[test]
fn desync_forensics_len_bytes_marks_truncation_for_oversize_values() {
let (small_bytes, small_truncated) = desync_forensics_len_bytes(0x1020_3040);
assert_eq!(small_bytes, 0x1020_3040u32.to_le_bytes());
assert!(!small_truncated);
let (huge_bytes, huge_truncated) = desync_forensics_len_bytes(usize::MAX);
assert_eq!(huge_bytes, u32::MAX.to_le_bytes());
assert!(huge_truncated);
}
#[test]
fn report_desync_frame_too_large_preserves_full_length_in_error_message() {
let state = RelayForensicsState {
trace_id: 0x1234,
conn_id: 0x5678,
user: "middle-desync-oversize".to_string(),
peer: "198.51.100.55:443".parse().expect("valid test peer"),
peer_hash: 0xAABBCCDD,
started_at: Instant::now(),
bytes_c2me: 7,
bytes_me2c: Arc::new(AtomicU64::new(9)),
desync_all_full: false,
};
let huge_len = usize::MAX;
let err = report_desync_frame_too_large(
&state,
ProtoTag::Intermediate,
3,
1024,
huge_len,
None,
&Stats::new(),
);
let msg = format!("{err}");
assert!(
msg.contains(&huge_len.to_string()),
"error must preserve full usize length for forensics"
);
}

View File

@ -0,0 +1,131 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
#[test]
fn saturation_uses_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
let user = format!("middle-quota-overflow-{}", std::process::id());
let first = quota_user_lock(&user);
let second = quota_user_lock(&user);
assert!(
Arc::ptr_eq(&first, &second),
"overflow user must get deterministic same lock while cache is saturated"
);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock map"
);
assert!(
map.get(&user).is_none(),
"overflow user should stay outside bounded lock map under saturation"
);
drop(retained);
}
#[test]
fn overflow_striping_keeps_different_users_distributed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-dist-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
let a = quota_user_lock("middle-overflow-user-a");
let b = quota_user_lock("middle-overflow-user-b");
let c = quota_user_lock("middle-overflow-user-c");
let distinct = [
Arc::as_ptr(&a) as usize,
Arc::as_ptr(&b) as usize,
Arc::as_ptr(&c) as usize,
]
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
assert!(
distinct >= 2,
"striped overflow lock set should avoid collapsing all users to one lock"
);
drop(retained);
}
#[test]
fn reclaim_path_caches_new_user_after_stale_entries_drop() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-reclaim-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let user = format!("middle-quota-reclaim-user-{}", std::process::id());
let got = quota_user_lock(&user);
assert!(map.get(&user).is_some());
assert!(
Arc::strong_count(&got) >= 2,
"after reclaim, lock should be held both by caller and map"
);
}
#[test]
fn overflow_path_same_user_is_stable_across_parallel_threads() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-quota-thread-held-{}-{idx}",
std::process::id()
)));
}
let user = format!("middle-quota-overflow-thread-user-{}", std::process::id());
let mut workers = Vec::new();
for _ in 0..32 {
let user = user.clone();
workers.push(std::thread::spawn(move || quota_user_lock(&user)));
}
let first = workers
.remove(0)
.join()
.expect("thread must return lock handle");
for worker in workers {
let got = worker.join().expect("thread must return lock handle");
assert!(
Arc::ptr_eq(&first, &got),
"same overflow user should resolve to one striped lock even under contention"
);
}
drop(retained);
}

View File

@ -1,27 +1,27 @@
use super::*; use super::*;
use crate::proxy::handshake::HandshakeSuccess; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use bytes::Bytes;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode};
use crate::network::probe::NetworkDecision; use crate::network::probe::NetworkDecision;
use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use bytes::Bytes;
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng}; use rand::{RngExt, SeedableRng};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::thread; use std::thread;
use tokio::sync::Barrier;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::sync::Barrier;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{Duration as TokioDuration, timeout};
use std::sync::{Mutex, OnceLock};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer { fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
@ -38,16 +38,17 @@ fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer
payload payload
} }
fn quota_user_lock_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[test] #[test]
fn should_yield_sender_only_on_budget_with_backlog() { fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true)); assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); assert!(!should_yield_c2me_sender(
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); C2ME_SENDER_FAIRNESS_BUDGET - 1,
true
));
assert!(!should_yield_c2me_sender(
C2ME_SENDER_FAIRNESS_BUDGET,
false
));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
} }
@ -125,14 +126,7 @@ async fn enqueue_c2me_command_closed_channel_recycles_payload() {
let (tx, rx) = mpsc::channel::<C2MeCommand>(1); let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
drop(rx); drop(rx);
let result = enqueue_c2me_command( let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await;
&tx,
C2MeCommand::Data {
payload,
flags: 0,
},
)
.await;
assert!(result.is_err(), "closed queue must fail enqueue"); assert!(result.is_err(), "closed queue must fail enqueue");
drop(result); drop(result);
@ -244,6 +238,11 @@ fn desync_dedup_cache_is_bounded() {
#[test] #[test]
fn quota_user_lock_cache_reuses_entry_for_same_user() { fn quota_user_lock_cache_reuses_entry_for_same_user() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let a = quota_user_lock("quota-user-a"); let a = quota_user_lock("quota-user-a");
let b = quota_user_lock("quota-user-a"); let b = quota_user_lock("quota-user-a");
assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock");
@ -251,9 +250,7 @@ fn quota_user_lock_cache_reuses_entry_for_same_user() {
#[test] #[test]
fn quota_user_lock_cache_is_bounded_under_unique_churn() { fn quota_user_lock_cache_is_bounded_under_unique_churn() {
let _guard = quota_user_lock_test_lock() let _guard = super::quota_user_lock_test_scope();
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear(); map.clear();
@ -271,10 +268,8 @@ fn quota_user_lock_cache_is_bounded_under_unique_churn() {
} }
#[test] #[test]
fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() {
let _guard = quota_user_lock_test_lock() let _guard = super::quota_user_lock_test_scope();
.lock()
.expect("quota user lock test lock must be available");
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
for attempt in 0..8u32 { for attempt in 0..8u32 {
@ -306,17 +301,15 @@ fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() {
"overflow path should not cache new user lock when map is saturated and all entries are retained" "overflow path should not cache new user lock when map is saturated and all entries are retained"
); );
assert!( assert!(
!Arc::ptr_eq(&overflow_a, &overflow_b), Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user lock should be ephemeral under saturation to preserve bounded cache size" "overflow user lock should use deterministic striping under saturation"
); );
drop(retained); drop(retained);
return; return;
} }
panic!( panic!("unable to observe stable saturated lock-cache precondition after bounded retries");
"unable to observe stable saturated lock-cache precondition after bounded retries"
);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -390,14 +383,7 @@ async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_succe
12_000 + round, 12_000 + round,
barrier.clone(), barrier.clone(),
); );
let two = run_quota_race_attempt( let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier);
&stats,
&bytes_me2c,
&user,
0x72,
13_000 + round,
barrier,
);
let (r1, r2) = tokio::join!(one, two); let (r1, r2) = tokio::join!(one, two);
assert!( assert!(
@ -823,7 +809,9 @@ fn full_cache_gate_lock_poison_is_fail_closed_without_panic() {
// Poison the full-cache gate lock intentionally. // Poison the full-cache gate lock intentionally.
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
let _ = std::panic::catch_unwind(|| { let _ = std::panic::catch_unwind(|| {
let _lock = gate.lock().expect("gate lock must be lockable before poison"); let _lock = gate
.lock()
.expect("gate lock must be lockable before poison");
panic!("intentional gate poison for fail-closed regression"); panic!("intentional gate poison for fail-closed regression");
}); });
@ -961,24 +949,6 @@ fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() {
panic!("expected at least one post-window sample to re-emit forensic record"); panic!("expected at least one post-window sample to re-emit forensic record");
} }
#[test]
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
fn should_emit_full_desync_filters_duplicates() {
unimplemented!("Stub for M-04");
}
#[test]
#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"]
fn desync_dedup_eviction_under_map_full_condition() {
unimplemented!("Stub for M-04");
}
#[tokio::test]
#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"]
async fn c2me_channel_full_path_yields_then_sends() {
unimplemented!("Stub for M-05");
}
fn make_forensics_state() -> RelayForensicsState { fn make_forensics_state() -> RelayForensicsState {
RelayForensicsState { RelayForensicsState {
trace_id: 1, trace_id: 1,
@ -1208,7 +1178,11 @@ async fn read_client_payload_large_intermediate_frame_is_exact() {
let (frame, quickack) = read; let (frame, quickack) = read;
assert!(!quickack, "quickack flag must be unset"); assert!(!quickack, "quickack flag must be unset");
assert_eq!(frame.len(), payload_len, "payload size must match wire length"); assert_eq!(
frame.len(),
payload_len,
"payload size must match wire length"
);
for (idx, byte) in frame.iter().enumerate() { for (idx, byte) in frame.iter().enumerate() {
assert_eq!(*byte, (idx as u8).wrapping_mul(31)); assert_eq!(*byte, (idx as u8).wrapping_mul(31));
} }
@ -1376,7 +1350,10 @@ async fn read_client_payload_abridged_extended_len_sets_quickack() {
.expect("frame must be present"); .expect("frame must be present");
let (frame, quickack) = read; let (frame, quickack) = read;
assert!(quickack, "quickack bit must be propagated from abridged header"); assert!(
quickack,
"quickack bit must be propagated from abridged header"
);
assert_eq!(frame.len(), payload_len); assert_eq!(frame.len(), payload_len);
assert_eq!(frame_counter, 1, "one abridged frame must be counted"); assert_eq!(frame_counter, 1, "one abridged frame must be counted");
} }
@ -1436,7 +1413,11 @@ async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() {
let pool = Arc::new(BufferPool::with_config(64, 2)); let pool = Arc::new(BufferPool::with_config(64, 2));
pool.preallocate(1); pool.preallocate(1);
assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available"); assert_eq!(
pool.stats().pooled,
1,
"one pooled buffer must be available"
);
let (reader, mut writer) = duplex(1024); let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader); let mut crypto_reader = make_crypto_reader(reader);
@ -1491,7 +1472,8 @@ async fn enqueue_c2me_close_unblocks_after_queue_drain() {
.unwrap(); .unwrap();
let tx2 = tx.clone(); let tx2 = tx.clone();
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); let close_task =
tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
tokio::time::sleep(TokioDuration::from_millis(10)).await; tokio::time::sleep(TokioDuration::from_millis(10)).await;
@ -1501,7 +1483,10 @@ async fn enqueue_c2me_close_unblocks_after_queue_drain() {
.expect("first queued item must be present"); .expect("first queued item must be present");
assert!(matches!(first, C2MeCommand::Data { .. })); assert!(matches!(first, C2MeCommand::Data { .. }));
close_task.await.unwrap().expect("close enqueue must succeed after drain"); close_task
.await
.unwrap()
.expect("close enqueue must succeed after drain");
let second = timeout(TokioDuration::from_millis(100), rx.recv()) let second = timeout(TokioDuration::from_millis(100), rx.recv())
.await .await
@ -1521,7 +1506,8 @@ async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() {
.unwrap(); .unwrap();
let tx2 = tx.clone(); let tx2 = tx.clone();
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); let close_task =
tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
tokio::time::sleep(TokioDuration::from_millis(10)).await; tokio::time::sleep(TokioDuration::from_millis(10)).await;
drop(rx); drop(rx);
@ -1756,7 +1742,8 @@ async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoo
} }
#[tokio::test] #[tokio::test]
async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() { async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message()
{
let (writer_side, mut reader_side) = duplex(1024); let (writer_side, mut reader_side) = duplex(1024);
let mut writer = make_crypto_writer(writer_side); let mut writer = make_crypto_writer(writer_side);
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -1851,11 +1838,17 @@ async fn middle_relay_abort_midflight_releases_route_gauge() {
} }
}) })
.await; .await;
assert!(started.is_ok(), "middle relay must increment route gauge before abort"); assert!(
started.is_ok(),
"middle relay must increment route gauge before abort"
);
relay_task.abort(); relay_task.abort();
let joined = relay_task.await; let joined = relay_task.await;
assert!(joined.is_err(), "aborted middle relay task must return join error"); assert!(
joined.is_err(),
"aborted middle relay task must return join error"
);
tokio::time::sleep(TokioDuration::from_millis(20)).await; tokio::time::sleep(TokioDuration::from_millis(20)).await;
assert_eq!( assert_eq!(
@ -2014,8 +2007,14 @@ async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read
) )
.await; .await;
assert!(result.is_err(), "oversized abridged length must fail closed"); assert!(
assert_eq!(frame_counter, 0, "oversized frame must not be counted as accepted"); result.is_err(),
"oversized abridged length must fail closed"
);
assert_eq!(
frame_counter, 0,
"oversized frame must not be counted as accepted"
);
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -2067,14 +2066,7 @@ async fn stress_quota_race_bursts_never_allow_double_success_per_round() {
6000 + round, 6000 + round,
barrier.clone(), barrier.clone(),
); );
let two = run_quota_race_attempt( let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier);
&stats,
&bytes_me2c,
&user,
0x44,
7000 + round,
barrier,
);
let (r1, r2) = tokio::join!(one, two); let (r1, r2) = tokio::join!(one, two);
assert!( assert!(

Some files were not shown because too many files have changed in this diff Show More