Compare commits

..

No commits in common. "2d3c2807abada98276bdbb73cafefea096dd356b" and "54d65dd124bf308e10d81bf5e97d6a1956694003" have entirely different histories.

173 changed files with 4192 additions and 8925 deletions

View File

@ -21,13 +21,16 @@ env:
jobs: jobs:
prepare: prepare:
name: Prepare metadata
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs: outputs:
version: ${{ steps.meta.outputs.version }} version: ${{ steps.meta.outputs.version }}
prerelease: ${{ steps.meta.outputs.prerelease }} prerelease: ${{ steps.meta.outputs.prerelease }}
release_enabled: ${{ steps.meta.outputs.release_enabled }} release_enabled: ${{ steps.meta.outputs.release_enabled }}
steps: steps:
- id: meta - name: Derive version
id: meta
shell: bash
run: | run: |
set -euo pipefail set -euo pipefail
@ -50,38 +53,67 @@ jobs:
echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT" echo "release_enabled=$RELEASE_ENABLED" >> "$GITHUB_OUTPUT"
checks: checks:
name: Checks
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: debian:trixie image: debian:trixie
steps: steps:
- run: | - name: Install system dependencies
shell: bash
run: |
set -euo pipefail
apt-get update apt-get update
apt-get install -y build-essential clang llvm pkg-config curl git apt-get install -y --no-install-recommends \
ca-certificates \
curl \
git \
build-essential \
pkg-config \
clang \
llvm \
python3 \
python3-pip
update-ca-certificates
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable - uses: dtolnay/rust-toolchain@stable
with: with:
components: rustfmt, clippy components: rustfmt, clippy
- uses: actions/cache@v4 - name: Cache cargo
uses: actions/cache@v4
with: with:
path: | path: |
/github/home/.cargo/registry /github/home/.cargo/registry
/github/home/.cargo/git /github/home/.cargo/git
target target
key: checks-${{ hashFiles('**/Cargo.lock') }} key: checks-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
checks-${{ runner.os }}-
- run: cargo fetch --locked - name: Cargo fetch
- run: cargo fmt --all -- --check shell: bash
- run: cargo clippy run: cargo fetch --locked
- run: cargo test
- name: Format
shell: bash
run: cargo fmt --all -- --check
- name: Clippy
shell: bash
run: cargo clippy --workspace --all-targets --locked -- -D warnings
- name: Tests
shell: bash
run: cargo test --workspace --all-targets --locked
build-binaries: build-binaries:
name: Build ${{ matrix.asset_name }}
needs: [prepare, checks] needs: [prepare, checks]
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
image: debian:trixie image: debian:trixie
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -100,80 +132,156 @@ jobs:
asset_name: telemt-aarch64-linux-musl asset_name: telemt-aarch64-linux-musl
steps: steps:
- run: | - name: Install system dependencies
shell: bash
run: |
set -euo pipefail
apt-get update apt-get update
apt-get install -y clang llvm pkg-config curl git python3 python3-pip file tar xz-utils apt-get install -y --no-install-recommends \
ca-certificates \
curl \
git \
build-essential \
pkg-config \
clang \
llvm \
file \
tar \
xz-utils \
python3 \
python3-pip
update-ca-certificates
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable - uses: dtolnay/rust-toolchain@stable
with: with:
targets: ${{ matrix.rust_target }} targets: ${{ matrix.rust_target }}
- uses: actions/cache@v4 - name: Cache cargo
uses: actions/cache@v4
with: with:
path: | path: |
/github/home/.cargo/registry /github/home/.cargo/registry
/github/home/.cargo/git /github/home/.cargo/git
target target
key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }} key: build-${{ matrix.zig_target }}-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
build-${{ matrix.zig_target }}-
- run: | - name: Install cargo-zigbuild + Zig
shell: bash
run: |
set -euo pipefail
python3 -m pip install --user --break-system-packages cargo-zigbuild python3 -m pip install --user --break-system-packages cargo-zigbuild
echo "/github/home/.local/bin" >> "$GITHUB_PATH" echo "/github/home/.local/bin" >> "$GITHUB_PATH"
- run: cargo fetch --locked - name: Cargo fetch
shell: bash
run: cargo fetch --locked
- run: | - name: Build release
shell: bash
env:
CARGO_PROFILE_RELEASE_LTO: "fat"
CARGO_PROFILE_RELEASE_CODEGEN_UNITS: "1"
CARGO_PROFILE_RELEASE_PANIC: "abort"
run: |
set -euo pipefail
cargo zigbuild --release --locked --target "${{ matrix.zig_target }}" cargo zigbuild --release --locked --target "${{ matrix.zig_target }}"
- run: | - name: Strip binary
BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}" shell: bash
llvm-strip "$BIN" || true run: |
set -euo pipefail
llvm-strip "target/${{ matrix.zig_target }}/release/${BINARY_NAME}" || true
- run: | - name: Inspect binary
BIN="target/${{ matrix.rust_target }}/release/${BINARY_NAME}" shell: bash
OUT="$RUNNER_TEMP/${{ matrix.asset_name }}" run: |
mkdir -p "$OUT" set -euo pipefail
install -m755 "$BIN" "$OUT/${BINARY_NAME}" file "target/${{ matrix.zig_target }}/release/${BINARY_NAME}"
tar -C "$RUNNER_TEMP" -czf "${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}" - name: Package
sha256sum "${{ matrix.asset_name }}.tar.gz" > "${{ matrix.asset_name }}.sha256" shell: bash
run: |
set -euo pipefail
OUTDIR="$RUNNER_TEMP/pkg/${{ matrix.asset_name }}"
mkdir -p "$OUTDIR"
install -m 0755 "target/${{ matrix.zig_target }}/release/${BINARY_NAME}" "$OUTDIR/${BINARY_NAME}"
if [[ -f LICENSE ]]; then cp LICENSE "$OUTDIR/"; fi
if [[ -f README.md ]]; then cp README.md "$OUTDIR/"; fi
cat > "$OUTDIR/BUILD-INFO.txt" <<EOF
project=${GITHUB_REPOSITORY}
version=${{ needs.prepare.outputs.version }}
git_ref=${GITHUB_REF}
git_sha=${GITHUB_SHA}
rust_target=${{ matrix.rust_target }}
zig_target=${{ matrix.zig_target }}
built_at=$(date -u +%Y-%m-%dT%H:%M:%SZ)
EOF
mkdir -p dist
tar -C "$RUNNER_TEMP/pkg" -czf "dist/${{ matrix.asset_name }}.tar.gz" "${{ matrix.asset_name }}"
sha256sum "dist/${{ matrix.asset_name }}.tar.gz" > "dist/${{ matrix.asset_name }}.sha256"
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.asset_name }} name: ${{ matrix.asset_name }}
path: | path: |
${{ matrix.asset_name }}.tar.gz dist/${{ matrix.asset_name }}.tar.gz
${{ matrix.asset_name }}.sha256 dist/${{ matrix.asset_name }}.sha256
if-no-files-found: error
retention-days: 14
attest-binaries:
name: Attest binary archives
needs: build-binaries
runs-on: ubuntu-latest
permissions:
contents: read
attestations: write
id-token: write
steps:
- uses: actions/download-artifact@v4
with:
path: dist
- name: Flatten artifacts
shell: bash
run: |
set -euo pipefail
mkdir -p upload
find dist -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \;
ls -lah upload
- name: Attest release archives
uses: actions/attest-build-provenance@v3
with:
subject-path: 'upload/*.tar.gz'
docker-image: docker-image:
name: Docker ${{ matrix.platform }} name: Build and push GHCR image
needs: [prepare, build-binaries] needs: [prepare, checks]
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
strategy: contents: read
matrix: packages: write
include:
- platform: linux/amd64
artifact: telemt-x86_64-linux-gnu
- platform: linux/arm64
artifact: telemt-aarch64-linux-gnu
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/download-artifact@v4 - name: Set up QEMU
with: uses: docker/setup-qemu-action@v3
name: ${{ matrix.artifact }}
path: dist
- run: | - name: Set up Buildx
mkdir docker-build uses: docker/setup-buildx-action@v3
tar -xzf dist/*.tar.gz -C docker-build --strip-components=1
- uses: docker/setup-buildx-action@v3 - name: Log in to GHCR
- name: Login
if: ${{ needs.prepare.outputs.release_enabled == 'true' }} if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
@ -181,20 +289,43 @@ jobs:
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- uses: docker/build-push-action@v6 - name: Docker metadata
id: meta
uses: docker/metadata-action@v5
with: with:
context: ./docker-build images: ghcr.io/${{ github.repository }}
platforms: ${{ matrix.platform }} tags: |
type=raw,value=${{ needs.prepare.outputs.version }}
type=raw,value=latest,enable=${{ needs.prepare.outputs.prerelease != 'true' && needs.prepare.outputs.release_enabled == 'true' }}
labels: |
org.opencontainers.image.title=telemt
org.opencontainers.image.description=telemt
org.opencontainers.image.source=https://github.com/${{ github.repository }}
org.opencontainers.image.version=${{ needs.prepare.outputs.version }}
org.opencontainers.image.revision=${{ github.sha }}
- name: Build and push
id: build
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: ${{ needs.prepare.outputs.release_enabled == 'true' }} push: ${{ needs.prepare.outputs.release_enabled == 'true' }}
tags: ghcr.io/${{ github.repository }}:${{ needs.prepare.outputs.version }} tags: ${{ steps.meta.outputs.tags }}
cache-from: type=gha,scope=telemt-${{ matrix.platform }} labels: ${{ steps.meta.outputs.labels }}
cache-to: type=gha,mode=max,scope=telemt-${{ matrix.platform }} cache-from: type=gha
provenance: false cache-to: type=gha,mode=max
sbom: false provenance: mode=max
sbom: true
build-args: |
TELEMT_VERSION=${{ needs.prepare.outputs.version }}
VCS_REF=${{ github.sha }}
release: release:
name: Create GitHub Release
if: ${{ needs.prepare.outputs.release_enabled == 'true' }} if: ${{ needs.prepare.outputs.release_enabled == 'true' }}
needs: [prepare, build-binaries] needs: [prepare, build-binaries, attest-binaries, docker-image]
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write
@ -203,14 +334,19 @@ jobs:
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
path: release-artifacts path: release-artifacts
pattern: telemt-*
- run: | - name: Flatten artifacts
mkdir upload shell: bash
run: |
set -euo pipefail
mkdir -p upload
find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \; find release-artifacts -type f \( -name '*.tar.gz' -o -name '*.sha256' \) -exec cp {} upload/ \;
ls -lah upload
- uses: softprops/action-gh-release@v2 - name: Create release
uses: softprops/action-gh-release@v2
with: with:
files: upload/* files: upload/*
generate_release_notes: true generate_release_notes: true
draft: false
prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }} prerelease: ${{ needs.prepare.outputs.prerelease == 'true' }}

6
Cargo.lock generated
View File

@ -90,9 +90,9 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]] [[package]]
name = "arc-swap" name = "arc-swap"
version = "1.9.0" version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5"
dependencies = [ dependencies = [
"rustversion", "rustversion",
] ]
@ -2771,7 +2771,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.3.29" version = "3.4.0"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",

View File

@ -1,6 +1,6 @@
[package] [package]
name = "telemt" name = "telemt"
version = "3.3.29" version = "3.4.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
@ -27,7 +27,7 @@ static_assertions = "1.1"
# Network # Network
socket2 = { version = "0.6", features = ["all"] } socket2 = { version = "0.6", features = ["all"] }
nix = { version = "0.31", default-features = false, features = ["net", "fs"] } nix = { version = "0.31", default-features = false, features = ["net"] }
shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] }
# Serialization # Serialization

View File

@ -1,5 +1,5 @@
// Cryptobench // Cryptobench
use criterion::{Criterion, black_box, criterion_group}; use criterion::{black_box, criterion_group, Criterion};
fn bench_aes_ctr(c: &mut Criterion) { fn bench_aes_ctr(c: &mut Criterion) {
c.bench_function("aes_ctr_encrypt_64kb", |b| { c.bench_function("aes_ctr_encrypt_64kb", |b| {
@ -9,4 +9,4 @@ fn bench_aes_ctr(c: &mut Criterion) {
black_box(enc.encrypt(&data)) black_box(enc.encrypt(&data))
}) })
}); });
} }

View File

@ -261,7 +261,6 @@ This document lists all configuration keys accepted by `config.toml`.
| alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. | | alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. |
| mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). | | mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). |
| mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. | | mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. |
| mask_shape_hardening_aggressive_mode | `bool` | `false` | Requires `mask_shape_hardening = true`. | Opt-in aggressive shaping profile: allows shaping on backend-silent non-EOF paths and switches above-cap blur to strictly positive random tail. |
| mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. | | mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. |
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | | mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | | mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
@ -285,27 +284,6 @@ When `mask_shape_hardening = true`, Telemt pads the **client->mask** stream tail
This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder. This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder.
What each parameter changes in practice:
- `mask_shape_hardening`
Enables or disables this entire length-shaping stage on the fallback path.
When `false`, backend-observed length stays close to the real forwarded probe length.
When `true`, clean relay shutdown can append random padding bytes to move the total into a bucket.
- `mask_shape_bucket_floor_bytes`
Sets the first bucket boundary used for small probes.
Example: with floor `512`, a malformed probe that would otherwise forward `37` bytes can be expanded to `512` bytes on clean EOF.
Larger floor values hide very small probes better, but increase egress cost.
- `mask_shape_bucket_cap_bytes`
Sets the largest bucket Telemt will pad up to with bucket logic.
Example: with cap `4096`, a forwarded total of `1800` bytes may be padded to `2048` or `4096` depending on the bucket ladder, but a total already above `4096` will not be bucket-padded further.
Larger cap values increase the range over which size classes are collapsed, but also increase worst-case overhead.
- Clean EOF matters in conservative mode
In the default profile, shape padding is intentionally conservative: it is applied on clean relay shutdown, not on every timeout/drip path.
This avoids introducing new timeout-tail artifacts that some backends or tests interpret as a separate fingerprint.
Practical trade-offs: Practical trade-offs:
- Better anti-fingerprinting on size/shape channel. - Better anti-fingerprinting on size/shape channel.
@ -318,56 +296,14 @@ Recommended starting profile:
- `mask_shape_bucket_floor_bytes = 512` - `mask_shape_bucket_floor_bytes = 512`
- `mask_shape_bucket_cap_bytes = 4096` - `mask_shape_bucket_cap_bytes = 4096`
### Aggressive mode notes (`[censorship]`)
`mask_shape_hardening_aggressive_mode` is an opt-in profile for higher anti-classifier pressure.
- Default is `false` to preserve conservative timeout/no-tail behavior.
- Requires `mask_shape_hardening = true`.
- When enabled, backend-silent non-EOF masking paths may be shaped.
- When enabled together with above-cap blur, the random extra tail uses `[1, max]` instead of `[0, max]`.
What changes when aggressive mode is enabled:
- Backend-silent timeout paths can be shaped
In default mode, a client that keeps the socket half-open and times out will usually not receive shape padding on that path.
In aggressive mode, Telemt may still shape that backend-silent session if no backend bytes were returned.
This is specifically aimed at active probes that try to avoid EOF in order to preserve an exact backend-observed length.
- Above-cap blur always adds at least one byte
In default mode, above-cap blur may choose `0`, so some oversized probes still land on their exact base forwarded length.
In aggressive mode, that exact-base sample is removed by construction.
- Tradeoff
Aggressive mode improves resistance to active length classifiers, but it is more opinionated and less conservative.
If your deployment prioritizes strict compatibility with timeout/no-tail semantics, leave it disabled.
If your threat model includes repeated active probing by a censor, this mode is the stronger profile.
Use this mode only when your threat model prioritizes classifier resistance over strict compatibility with conservative masking semantics.
### Above-cap blur notes (`[censorship]`) ### Above-cap blur notes (`[censorship]`)
`mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`. `mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`.
- A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended in default mode. - A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended.
- In aggressive mode, the random tail becomes strictly positive: `[1, mask_shape_above_cap_blur_max_bytes]`.
- This reduces exact-size leakage above cap at bounded overhead. - This reduces exact-size leakage above cap at bounded overhead.
- Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth. - Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth.
Operational meaning:
- Without above-cap blur
A probe that forwards `5005` bytes will still look like `5005` bytes to the backend if it is already above cap.
- With above-cap blur enabled
That same probe may look like any value in a bounded window above its base length.
Example with `mask_shape_above_cap_blur_max_bytes = 64`:
backend-observed size becomes `5005..5069` in default mode, or `5006..5069` in aggressive mode.
- Choosing `mask_shape_above_cap_blur_max_bytes`
Small values reduce cost but preserve more separability between far-apart oversized classes.
Larger values blur oversized classes more aggressively, but add more egress overhead and more output variance.
### Timing normalization envelope notes (`[censorship]`) ### Timing normalization envelope notes (`[censorship]`)
`mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope. `mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope.

View File

@ -24,7 +24,10 @@ pub(super) fn success_response<T: Serialize>(
.unwrap() .unwrap()
} }
pub(super) fn error_response(request_id: u64, failure: ApiFailure) -> hyper::Response<Full<Bytes>> { pub(super) fn error_response(
request_id: u64,
failure: ApiFailure,
) -> hyper::Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
ok: false, ok: false,
error: ErrorBody { error: ErrorBody {

View File

@ -1,5 +1,3 @@
#![allow(clippy::too_many_arguments)]
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
@ -21,8 +19,8 @@ use crate::ip_tracker::UserIpTracker;
use crate::proxy::route_mode::RouteRuntimeController; use crate::proxy::route_mode::RouteRuntimeController;
use crate::startup::StartupTracker; use crate::startup::StartupTracker;
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
mod config_store; mod config_store;
mod events; mod events;
@ -38,8 +36,8 @@ mod runtime_zero;
mod users; mod users;
use config_store::{current_revision, parse_if_match}; use config_store::{current_revision, parse_if_match};
use events::ApiEventStore;
use http_utils::{error_response, read_json, read_optional_json, success_response}; use http_utils::{error_response, read_json, read_optional_json, success_response};
use events::ApiEventStore;
use model::{ use model::{
ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData, ApiFailure, CreateUserRequest, HealthData, PatchUserRequest, RotateSecretRequest, SummaryData,
}; };
@ -57,11 +55,11 @@ use runtime_stats::{
MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data, MinimalCacheEntry, build_dcs_data, build_me_writers_data, build_minimal_all_data,
build_upstreams_data, build_zero_all_data, build_upstreams_data, build_zero_all_data,
}; };
use runtime_watch::spawn_runtime_watchers;
use runtime_zero::{ use runtime_zero::{
build_limits_effective_data, build_runtime_gates_data, build_security_posture_data, build_limits_effective_data, build_runtime_gates_data, build_security_posture_data,
build_system_info_data, build_system_info_data,
}; };
use runtime_watch::spawn_runtime_watchers;
use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config}; use users::{create_user, delete_user, patch_user, rotate_secret, users_from_config};
pub(super) struct ApiRuntimeState { pub(super) struct ApiRuntimeState {
@ -210,15 +208,15 @@ async fn handle(
)); ));
} }
if !api_cfg.whitelist.is_empty() && !api_cfg.whitelist.iter().any(|net| net.contains(peer.ip())) if !api_cfg.whitelist.is_empty()
&& !api_cfg
.whitelist
.iter()
.any(|net| net.contains(peer.ip()))
{ {
return Ok(error_response( return Ok(error_response(
request_id, request_id,
ApiFailure::new( ApiFailure::new(StatusCode::FORBIDDEN, "forbidden", "Source IP is not allowed"),
StatusCode::FORBIDDEN,
"forbidden",
"Source IP is not allowed",
),
)); ));
} }
@ -349,8 +347,7 @@ async fn handle(
} }
("GET", "/v1/runtime/connections/summary") => { ("GET", "/v1/runtime/connections/summary") => {
let revision = current_revision(&shared.config_path).await?; let revision = current_revision(&shared.config_path).await?;
let data = let data = build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await;
build_runtime_connections_summary_data(shared.as_ref(), cfg.as_ref()).await;
Ok(success_response(StatusCode::OK, data, revision)) Ok(success_response(StatusCode::OK, data, revision))
} }
("GET", "/v1/runtime/events/recent") => { ("GET", "/v1/runtime/events/recent") => {
@ -392,16 +389,13 @@ async fn handle(
let (data, revision) = match result { let (data, revision) = match result {
Ok(ok) => ok, Ok(ok) => ok,
Err(error) => { Err(error) => {
shared shared.runtime_events.record("api.user.create.failed", error.code);
.runtime_events
.record("api.user.create.failed", error.code);
return Err(error); return Err(error);
} }
}; };
shared.runtime_events.record( shared
"api.user.create.ok", .runtime_events
format!("username={}", data.user.username), .record("api.user.create.ok", format!("username={}", data.user.username));
);
Ok(success_response(StatusCode::CREATED, data, revision)) Ok(success_response(StatusCode::CREATED, data, revision))
} }
_ => { _ => {
@ -420,8 +414,7 @@ async fn handle(
detected_ip_v6, detected_ip_v6,
) )
.await; .await;
if let Some(user_info) = if let Some(user_info) = users.into_iter().find(|entry| entry.username == user)
users.into_iter().find(|entry| entry.username == user)
{ {
return Ok(success_response(StatusCode::OK, user_info, revision)); return Ok(success_response(StatusCode::OK, user_info, revision));
} }
@ -442,8 +435,7 @@ async fn handle(
)); ));
} }
let expected_revision = parse_if_match(req.headers()); let expected_revision = parse_if_match(req.headers());
let body = let body = read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
read_json::<PatchUserRequest>(req.into_body(), body_limit).await?;
let result = patch_user(user, body, expected_revision, &shared).await; let result = patch_user(user, body, expected_revision, &shared).await;
let (data, revision) = match result { let (data, revision) = match result {
Ok(ok) => ok, Ok(ok) => ok,
@ -483,9 +475,10 @@ async fn handle(
return Err(error); return Err(error);
} }
}; };
shared shared.runtime_events.record(
.runtime_events "api.user.delete.ok",
.record("api.user.delete.ok", format!("username={}", deleted_user)); format!("username={}", deleted_user),
);
return Ok(success_response(StatusCode::OK, deleted_user, revision)); return Ok(success_response(StatusCode::OK, deleted_user, revision));
} }
if method == Method::POST if method == Method::POST

View File

@ -167,7 +167,11 @@ async fn current_me_pool_stage_progress(shared: &ApiShared) -> Option<f64> {
let pool = shared.me_pool.read().await.clone()?; let pool = shared.me_pool.read().await.clone()?;
let status = pool.api_status_snapshot().await; let status = pool.api_status_snapshot().await;
let configured_dc_groups = status.configured_dc_groups; let configured_dc_groups = status.configured_dc_groups;
let covered_dc_groups = status.dcs.iter().filter(|dc| dc.alive_writers > 0).count(); let covered_dc_groups = status
.dcs
.iter()
.filter(|dc| dc.alive_writers > 0)
.count();
let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups); let dc_coverage = ratio_01(covered_dc_groups, configured_dc_groups);
let writer_coverage = ratio_01(status.alive_writers, status.required_writers); let writer_coverage = ratio_01(status.alive_writers, status.required_writers);

View File

@ -2,8 +2,8 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::config::ApiConfig; use crate::config::ApiConfig;
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::UpstreamRouteKind;
use crate::transport::upstream::IpPreference; use crate::transport::upstream::IpPreference;
use crate::transport::UpstreamRouteKind;
use super::ApiShared; use super::ApiShared;
use super::model::{ use super::model::{

View File

@ -128,8 +128,7 @@ pub(super) fn build_system_info_data(
.runtime_state .runtime_state
.last_config_reload_epoch_secs .last_config_reload_epoch_secs
.load(Ordering::Relaxed); .load(Ordering::Relaxed);
let last_config_reload_epoch_secs = let last_config_reload_epoch_secs = (last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs);
(last_reload_epoch_secs > 0).then_some(last_reload_epoch_secs);
let git_commit = option_env!("TELEMT_GIT_COMMIT") let git_commit = option_env!("TELEMT_GIT_COMMIT")
.or(option_env!("VERGEN_GIT_SHA")) .or(option_env!("VERGEN_GIT_SHA"))
@ -154,10 +153,7 @@ pub(super) fn build_system_info_data(
uptime_seconds: shared.stats.uptime_secs(), uptime_seconds: shared.stats.uptime_secs(),
config_path: shared.config_path.display().to_string(), config_path: shared.config_path.display().to_string(),
config_hash: revision.to_string(), config_hash: revision.to_string(),
config_reload_count: shared config_reload_count: shared.runtime_state.config_reload_count.load(Ordering::Relaxed),
.runtime_state
.config_reload_count
.load(Ordering::Relaxed),
last_config_reload_epoch_secs, last_config_reload_epoch_secs,
} }
} }
@ -237,7 +233,9 @@ pub(super) fn build_limits_effective_data(cfg: &ProxyConfig) -> EffectiveLimitsD
adaptive_floor_writers_per_core_total: cfg adaptive_floor_writers_per_core_total: cfg
.general .general
.me_adaptive_floor_writers_per_core_total, .me_adaptive_floor_writers_per_core_total,
adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, adaptive_floor_cpu_cores_override: cfg
.general
.me_adaptive_floor_cpu_cores_override,
adaptive_floor_max_extra_writers_single_per_core: cfg adaptive_floor_max_extra_writers_single_per_core: cfg
.general .general
.me_adaptive_floor_max_extra_writers_single_per_core, .me_adaptive_floor_max_extra_writers_single_per_core,

View File

@ -46,9 +46,7 @@ pub(super) async fn create_user(
None => random_user_secret(), None => random_user_secret(),
}; };
if let Some(ad_tag) = body.user_ad_tag.as_ref() if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) {
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters", "user_ad_tag must be exactly 32 hex characters",
)); ));
@ -67,18 +65,12 @@ pub(super) async fn create_user(
)); ));
} }
cfg.access cfg.access.users.insert(body.username.clone(), secret.clone());
.users
.insert(body.username.clone(), secret.clone());
if let Some(ad_tag) = body.user_ad_tag { if let Some(ad_tag) = body.user_ad_tag {
cfg.access cfg.access.user_ad_tags.insert(body.username.clone(), ad_tag);
.user_ad_tags
.insert(body.username.clone(), ad_tag);
} }
if let Some(limit) = body.max_tcp_conns { if let Some(limit) = body.max_tcp_conns {
cfg.access cfg.access.user_max_tcp_conns.insert(body.username.clone(), limit);
.user_max_tcp_conns
.insert(body.username.clone(), limit);
} }
if let Some(expiration) = expiration { if let Some(expiration) = expiration {
cfg.access cfg.access
@ -86,9 +78,7 @@ pub(super) async fn create_user(
.insert(body.username.clone(), expiration); .insert(body.username.clone(), expiration);
} }
if let Some(quota) = body.data_quota_bytes { if let Some(quota) = body.data_quota_bytes {
cfg.access cfg.access.user_data_quota.insert(body.username.clone(), quota);
.user_data_quota
.insert(body.username.clone(), quota);
} }
let updated_limit = body.max_unique_ips; let updated_limit = body.max_unique_ips;
@ -118,15 +108,11 @@ pub(super) async fn create_user(
touched_sections.push(AccessSection::UserMaxUniqueIps); touched_sections.push(AccessSection::UserMaxUniqueIps);
} }
let revision = let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
if let Some(limit) = updated_limit { if let Some(limit) = updated_limit {
shared shared.ip_tracker.set_user_limit(&body.username, limit).await;
.ip_tracker
.set_user_limit(&body.username, limit)
.await;
} }
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@ -154,7 +140,12 @@ pub(super) async fn create_user(
recent_unique_ips: 0, recent_unique_ips: 0,
recent_unique_ips_list: Vec::new(), recent_unique_ips_list: Vec::new(),
total_octets: 0, total_octets: 0,
links: build_user_links(&cfg, &secret, detected_ip_v4, detected_ip_v6), links: build_user_links(
&cfg,
&secret,
detected_ip_v4,
detected_ip_v6,
),
}); });
Ok((CreateUserResponse { user, secret }, revision)) Ok((CreateUserResponse { user, secret }, revision))
@ -166,16 +157,12 @@ pub(super) async fn patch_user(
expected_revision: Option<String>, expected_revision: Option<String>,
shared: &ApiShared, shared: &ApiShared,
) -> Result<(UserInfo, String), ApiFailure> { ) -> Result<(UserInfo, String), ApiFailure> {
if let Some(secret) = body.secret.as_ref() if let Some(secret) = body.secret.as_ref() && !is_valid_user_secret(secret) {
&& !is_valid_user_secret(secret)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"secret must be exactly 32 hex characters", "secret must be exactly 32 hex characters",
)); ));
} }
if let Some(ad_tag) = body.user_ad_tag.as_ref() if let Some(ad_tag) = body.user_ad_tag.as_ref() && !is_valid_ad_tag(ad_tag) {
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request( return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters", "user_ad_tag must be exactly 32 hex characters",
)); ));
@ -200,14 +187,10 @@ pub(super) async fn patch_user(
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag); cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
} }
if let Some(limit) = body.max_tcp_conns { if let Some(limit) = body.max_tcp_conns {
cfg.access cfg.access.user_max_tcp_conns.insert(user.to_string(), limit);
.user_max_tcp_conns
.insert(user.to_string(), limit);
} }
if let Some(expiration) = expiration { if let Some(expiration) = expiration {
cfg.access cfg.access.user_expirations.insert(user.to_string(), expiration);
.user_expirations
.insert(user.to_string(), expiration);
} }
if let Some(quota) = body.data_quota_bytes { if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(user.to_string(), quota); cfg.access.user_data_quota.insert(user.to_string(), quota);
@ -215,9 +198,7 @@ pub(super) async fn patch_user(
let mut updated_limit = None; let mut updated_limit = None;
if let Some(limit) = body.max_unique_ips { if let Some(limit) = body.max_unique_ips {
cfg.access cfg.access.user_max_unique_ips.insert(user.to_string(), limit);
.user_max_unique_ips
.insert(user.to_string(), limit);
updated_limit = Some(limit); updated_limit = Some(limit);
} }
@ -282,8 +263,7 @@ pub(super) async fn rotate_secret(
AccessSection::UserDataQuota, AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps, AccessSection::UserMaxUniqueIps,
]; ];
let revision = let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips(); let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
@ -350,8 +330,7 @@ pub(super) async fn delete_user(
AccessSection::UserDataQuota, AccessSection::UserDataQuota,
AccessSection::UserMaxUniqueIps, AccessSection::UserMaxUniqueIps,
]; ];
let revision = let revision = save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?;
drop(_guard); drop(_guard);
shared.ip_tracker.remove_user_limit(user).await; shared.ip_tracker.remove_user_limit(user).await;
shared.ip_tracker.clear_user_ips(user).await; shared.ip_tracker.clear_user_ips(user).await;
@ -386,7 +365,12 @@ pub(super) async fn users_from_config(
.users .users
.get(&username) .get(&username)
.map(|secret| { .map(|secret| {
build_user_links(cfg, secret, startup_detected_ip_v4, startup_detected_ip_v6) build_user_links(
cfg,
secret,
startup_detected_ip_v4,
startup_detected_ip_v6,
)
}) })
.unwrap_or(UserLinks { .unwrap_or(UserLinks {
classic: Vec::new(), classic: Vec::new(),
@ -408,8 +392,10 @@ pub(super) async fn users_from_config(
.get(&username) .get(&username)
.copied() .copied()
.filter(|limit| *limit > 0) .filter(|limit| *limit > 0)
.or((cfg.access.user_max_unique_ips_global_each > 0) .or(
.then_some(cfg.access.user_max_unique_ips_global_each)), (cfg.access.user_max_unique_ips_global_each > 0)
.then_some(cfg.access.user_max_unique_ips_global_each),
),
current_connections: stats.get_user_curr_connects(&username), current_connections: stats.get_user_curr_connects(&username),
active_unique_ips: active_ip_list.len(), active_unique_ips: active_ip_list.len(),
active_unique_ips_list: active_ip_list, active_unique_ips_list: active_ip_list,
@ -495,11 +481,11 @@ fn resolve_link_hosts(
push_unique_host(&mut hosts, host); push_unique_host(&mut hosts, host);
continue; continue;
} }
if let Some(ip) = listener.announce_ip if let Some(ip) = listener.announce_ip {
&& !ip.is_unspecified() if !ip.is_unspecified() {
{ push_unique_host(&mut hosts, &ip.to_string());
push_unique_host(&mut hosts, &ip.to_string()); continue;
continue; }
} }
if listener.ip.is_unspecified() { if listener.ip.is_unspecified() {
let detected_ip = if listener.ip.is_ipv4() { let detected_ip = if listener.ip.is_ipv4() {

View File

@ -1,9 +1,9 @@
//! CLI commands: --init (fire-and-forget setup) //! CLI commands: --init (fire-and-forget setup)
use rand::RngExt;
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use rand::RngExt;
/// Options for the init command /// Options for the init command
pub struct InitOptions { pub struct InitOptions {
@ -35,10 +35,10 @@ pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
if !args.iter().any(|a| a == "--init") { if !args.iter().any(|a| a == "--init") {
return None; return None;
} }
let mut opts = InitOptions::default(); let mut opts = InitOptions::default();
let mut i = 0; let mut i = 0;
while i < args.len() { while i < args.len() {
match args[i].as_str() { match args[i].as_str() {
"--port" => { "--port" => {
@ -78,7 +78,7 @@ pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
} }
i += 1; i += 1;
} }
Some(opts) Some(opts)
} }
@ -86,7 +86,7 @@ pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> { pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[telemt] Fire-and-forget setup"); eprintln!("[telemt] Fire-and-forget setup");
eprintln!(); eprintln!();
// 1. Generate or validate secret // 1. Generate or validate secret
let secret = match opts.secret { let secret = match opts.secret {
Some(s) => { Some(s) => {
@ -98,28 +98,28 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
} }
None => generate_secret(), None => generate_secret(),
}; };
eprintln!("[+] Secret: {}", secret); eprintln!("[+] Secret: {}", secret);
eprintln!("[+] User: {}", opts.username); eprintln!("[+] User: {}", opts.username);
eprintln!("[+] Port: {}", opts.port); eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain); eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory // 2. Create config directory
fs::create_dir_all(&opts.config_dir)?; fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml"); let config_path = opts.config_dir.join("config.toml");
// 3. Write config // 3. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain); let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?; fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display()); eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit // 4. Write systemd unit
let exe_path = let exe_path = std::env::current_exe()
std::env::current_exe().unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt")); .unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service"); let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path); let unit_content = generate_systemd_unit(&exe_path, &config_path);
match fs::write(unit_path, &unit_content) { match fs::write(unit_path, &unit_content) {
Ok(()) => { Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display()); eprintln!("[+] Systemd unit written to {}", unit_path.display());
@ -128,31 +128,31 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e); eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:"); eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content); eprintln!("{}", unit_content);
// Still print links and config // Still print links and config
print_links(&opts.username, &secret, opts.port, &opts.domain); print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(()); return Ok(());
} }
} }
// 5. Reload systemd // 5. Reload systemd
run_cmd("systemctl", &["daemon-reload"]); run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service // 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]); run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled"); eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start) // 7. Start service (unless --no-start)
if !opts.no_start { if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]); run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started"); eprintln!("[+] Service started");
// Brief delay then check status // Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1)); std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl") let status = Command::new("systemctl")
.args(["is-active", "telemt.service"]) .args(["is-active", "telemt.service"])
.output(); .output();
match status { match status {
Ok(out) if out.status.success() => { Ok(out) if out.status.success() => {
eprintln!("[+] Service is running"); eprintln!("[+] Service is running");
@ -166,12 +166,12 @@ pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[+] Service not started (--no-start)"); eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service"); eprintln!("[+] Start manually: systemctl start telemt.service");
} }
eprintln!(); eprintln!();
// 8. Print links // 8. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain); print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(()) Ok(())
} }
@ -183,7 +183,7 @@ fn generate_secret() -> String {
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String { fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!( format!(
r#"# Telemt MTProxy — auto-generated config r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate # Re-run `telemt --init` to regenerate
show_link = ["{username}"] show_link = ["{username}"]
@ -266,7 +266,7 @@ weight = 10
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String { fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!( format!(
r#"[Unit] r#"[Unit]
Description=Telemt MTProxy Description=Telemt MTProxy
Documentation=https://github.com/telemt/telemt Documentation=https://github.com/telemt/telemt
After=network-online.target After=network-online.target
@ -309,13 +309,11 @@ fn run_cmd(cmd: &str, args: &[&str]) {
fn print_links(username: &str, secret: &str, port: u16, domain: &str) { fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
let domain_hex = hex::encode(domain); let domain_hex = hex::encode(domain);
println!("=== Proxy Links ==="); println!("=== Proxy Links ===");
println!("[{}]", username); println!("[{}]", username);
println!( println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", port, secret, domain_hex);
port, secret, domain_hex
);
println!(); println!();
println!("Replace YOUR_SERVER_IP with your server's public IP."); println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup."); println!("The proxy will auto-detect and display the correct link on startup.");

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap;
// Helper defaults kept private to the config module. // Helper defaults kept private to the config module.
const DEFAULT_NETWORK_IPV6: Option<bool> = Some(false); const DEFAULT_NETWORK_IPV6: Option<bool> = Some(false);
@ -143,7 +143,10 @@ pub(crate) fn default_weight() -> u16 {
} }
pub(crate) fn default_metrics_whitelist() -> Vec<IpNetwork> { pub(crate) fn default_metrics_whitelist() -> Vec<IpNetwork> {
vec!["127.0.0.1/32".parse().unwrap(), "::1/128".parse().unwrap()] vec![
"127.0.0.1/32".parse().unwrap(),
"::1/128".parse().unwrap(),
]
} }
pub(crate) fn default_api_listen() -> String { pub(crate) fn default_api_listen() -> String {
@ -166,18 +169,10 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 {
1000 1000
} }
pub(crate) fn default_api_runtime_edge_enabled() -> bool { pub(crate) fn default_api_runtime_edge_enabled() -> bool { false }
false pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { 1000 }
} pub(crate) fn default_api_runtime_edge_top_n() -> usize { 10 }
pub(crate) fn default_api_runtime_edge_cache_ttl_ms() -> u64 { pub(crate) fn default_api_runtime_edge_events_capacity() -> usize { 256 }
1000
}
pub(crate) fn default_api_runtime_edge_top_n() -> usize {
10
}
pub(crate) fn default_api_runtime_edge_events_capacity() -> usize {
256
}
pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
500 500
@ -523,10 +518,6 @@ pub(crate) fn default_mask_shape_hardening() -> bool {
true true
} }
pub(crate) fn default_mask_shape_hardening_aggressive_mode() -> bool {
false
}
pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize { pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize {
512 512
} }

View File

@ -31,10 +31,11 @@ use notify::{EventKind, RecursiveMode, Watcher, recommended_watcher};
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use super::load::{LoadedConfig, ProxyConfig};
use crate::config::{ use crate::config::{
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode, LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
MeWriterPickMode,
}; };
use super::load::{LoadedConfig, ProxyConfig};
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
@ -43,16 +44,16 @@ const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
/// Fields that are safe to swap without restarting listeners. /// Fields that are safe to swap without restarting listeners.
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct HotFields { pub struct HotFields {
pub log_level: LogLevel, pub log_level: LogLevel,
pub ad_tag: Option<String>, pub ad_tag: Option<String>,
pub dns_overrides: Vec<String>, pub dns_overrides: Vec<String>,
pub desync_all_full: bool, pub desync_all_full: bool,
pub update_every_secs: u64, pub update_every_secs: u64,
pub me_reinit_every_secs: u64, pub me_reinit_every_secs: u64,
pub me_reinit_singleflight: bool, pub me_reinit_singleflight: bool,
pub me_reinit_coalesce_window_ms: u64, pub me_reinit_coalesce_window_ms: u64,
pub hardswap: bool, pub hardswap: bool,
pub me_pool_drain_ttl_secs: u64, pub me_pool_drain_ttl_secs: u64,
pub me_instadrain: bool, pub me_instadrain: bool,
pub me_pool_drain_threshold: u64, pub me_pool_drain_threshold: u64,
pub me_pool_min_fresh_ratio: f32, pub me_pool_min_fresh_ratio: f32,
@ -112,12 +113,12 @@ pub struct HotFields {
pub me_health_interval_ms_healthy: u64, pub me_health_interval_ms_healthy: u64,
pub me_admission_poll_ms: u64, pub me_admission_poll_ms: u64,
pub me_warn_rate_limit_ms: u64, pub me_warn_rate_limit_ms: u64,
pub users: std::collections::HashMap<String, String>, pub users: std::collections::HashMap<String, String>,
pub user_ad_tags: std::collections::HashMap<String, String>, pub user_ad_tags: std::collections::HashMap<String, String>,
pub user_max_tcp_conns: std::collections::HashMap<String, usize>, pub user_max_tcp_conns: std::collections::HashMap<String, usize>,
pub user_expirations: std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>, pub user_expirations: std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>,
pub user_data_quota: std::collections::HashMap<String, u64>, pub user_data_quota: std::collections::HashMap<String, u64>,
pub user_max_unique_ips: std::collections::HashMap<String, usize>, pub user_max_unique_ips: std::collections::HashMap<String, usize>,
pub user_max_unique_ips_global_each: usize, pub user_max_unique_ips_global_each: usize,
pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode, pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode,
pub user_max_unique_ips_window_secs: u64, pub user_max_unique_ips_window_secs: u64,
@ -126,16 +127,16 @@ pub struct HotFields {
impl HotFields { impl HotFields {
pub fn from_config(cfg: &ProxyConfig) -> Self { pub fn from_config(cfg: &ProxyConfig) -> Self {
Self { Self {
log_level: cfg.general.log_level.clone(), log_level: cfg.general.log_level.clone(),
ad_tag: cfg.general.ad_tag.clone(), ad_tag: cfg.general.ad_tag.clone(),
dns_overrides: cfg.network.dns_overrides.clone(), dns_overrides: cfg.network.dns_overrides.clone(),
desync_all_full: cfg.general.desync_all_full, desync_all_full: cfg.general.desync_all_full,
update_every_secs: cfg.general.effective_update_every_secs(), update_every_secs: cfg.general.effective_update_every_secs(),
me_reinit_every_secs: cfg.general.me_reinit_every_secs, me_reinit_every_secs: cfg.general.me_reinit_every_secs,
me_reinit_singleflight: cfg.general.me_reinit_singleflight, me_reinit_singleflight: cfg.general.me_reinit_singleflight,
me_reinit_coalesce_window_ms: cfg.general.me_reinit_coalesce_window_ms, me_reinit_coalesce_window_ms: cfg.general.me_reinit_coalesce_window_ms,
hardswap: cfg.general.hardswap, hardswap: cfg.general.hardswap,
me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs,
me_instadrain: cfg.general.me_instadrain, me_instadrain: cfg.general.me_instadrain,
me_pool_drain_threshold: cfg.general.me_pool_drain_threshold, me_pool_drain_threshold: cfg.general.me_pool_drain_threshold,
me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio, me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio,
@ -188,11 +189,15 @@ impl HotFields {
me_adaptive_floor_min_writers_multi_endpoint: cfg me_adaptive_floor_min_writers_multi_endpoint: cfg
.general .general
.me_adaptive_floor_min_writers_multi_endpoint, .me_adaptive_floor_min_writers_multi_endpoint,
me_adaptive_floor_recover_grace_secs: cfg.general.me_adaptive_floor_recover_grace_secs, me_adaptive_floor_recover_grace_secs: cfg
.general
.me_adaptive_floor_recover_grace_secs,
me_adaptive_floor_writers_per_core_total: cfg me_adaptive_floor_writers_per_core_total: cfg
.general .general
.me_adaptive_floor_writers_per_core_total, .me_adaptive_floor_writers_per_core_total,
me_adaptive_floor_cpu_cores_override: cfg.general.me_adaptive_floor_cpu_cores_override, me_adaptive_floor_cpu_cores_override: cfg
.general
.me_adaptive_floor_cpu_cores_override,
me_adaptive_floor_max_extra_writers_single_per_core: cfg me_adaptive_floor_max_extra_writers_single_per_core: cfg
.general .general
.me_adaptive_floor_max_extra_writers_single_per_core, .me_adaptive_floor_max_extra_writers_single_per_core,
@ -211,15 +216,9 @@ impl HotFields {
me_adaptive_floor_max_warm_writers_global: cfg me_adaptive_floor_max_warm_writers_global: cfg
.general .general
.me_adaptive_floor_max_warm_writers_global, .me_adaptive_floor_max_warm_writers_global,
me_route_backpressure_base_timeout_ms: cfg me_route_backpressure_base_timeout_ms: cfg.general.me_route_backpressure_base_timeout_ms,
.general me_route_backpressure_high_timeout_ms: cfg.general.me_route_backpressure_high_timeout_ms,
.me_route_backpressure_base_timeout_ms, me_route_backpressure_high_watermark_pct: cfg.general.me_route_backpressure_high_watermark_pct,
me_route_backpressure_high_timeout_ms: cfg
.general
.me_route_backpressure_high_timeout_ms,
me_route_backpressure_high_watermark_pct: cfg
.general
.me_route_backpressure_high_watermark_pct,
me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms, me_reader_route_data_wait_ms: cfg.general.me_reader_route_data_wait_ms,
me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames, me_d2c_flush_batch_max_frames: cfg.general.me_d2c_flush_batch_max_frames,
me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes, me_d2c_flush_batch_max_bytes: cfg.general.me_d2c_flush_batch_max_bytes,
@ -231,12 +230,12 @@ impl HotFields {
me_health_interval_ms_healthy: cfg.general.me_health_interval_ms_healthy, me_health_interval_ms_healthy: cfg.general.me_health_interval_ms_healthy,
me_admission_poll_ms: cfg.general.me_admission_poll_ms, me_admission_poll_ms: cfg.general.me_admission_poll_ms,
me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms, me_warn_rate_limit_ms: cfg.general.me_warn_rate_limit_ms,
users: cfg.access.users.clone(), users: cfg.access.users.clone(),
user_ad_tags: cfg.access.user_ad_tags.clone(), user_ad_tags: cfg.access.user_ad_tags.clone(),
user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(), user_max_tcp_conns: cfg.access.user_max_tcp_conns.clone(),
user_expirations: cfg.access.user_expirations.clone(), user_expirations: cfg.access.user_expirations.clone(),
user_data_quota: cfg.access.user_data_quota.clone(), user_data_quota: cfg.access.user_data_quota.clone(),
user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(),
user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each, user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each,
user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode, user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode,
user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs, user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs,
@ -335,9 +334,7 @@ struct ReloadState {
impl ReloadState { impl ReloadState {
fn new(applied_snapshot_hash: Option<u64>) -> Self { fn new(applied_snapshot_hash: Option<u64>) -> Self {
Self { Self { applied_snapshot_hash }
applied_snapshot_hash,
}
} }
fn is_applied(&self, hash: u64) -> bool { fn is_applied(&self, hash: u64) -> bool {
@ -484,14 +481,10 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
new.general.me_adaptive_floor_writers_per_core_total; new.general.me_adaptive_floor_writers_per_core_total;
cfg.general.me_adaptive_floor_cpu_cores_override = cfg.general.me_adaptive_floor_cpu_cores_override =
new.general.me_adaptive_floor_cpu_cores_override; new.general.me_adaptive_floor_cpu_cores_override;
cfg.general cfg.general.me_adaptive_floor_max_extra_writers_single_per_core =
.me_adaptive_floor_max_extra_writers_single_per_core = new new.general.me_adaptive_floor_max_extra_writers_single_per_core;
.general cfg.general.me_adaptive_floor_max_extra_writers_multi_per_core =
.me_adaptive_floor_max_extra_writers_single_per_core; new.general.me_adaptive_floor_max_extra_writers_multi_per_core;
cfg.general
.me_adaptive_floor_max_extra_writers_multi_per_core = new
.general
.me_adaptive_floor_max_extra_writers_multi_per_core;
cfg.general.me_adaptive_floor_max_active_writers_per_core = cfg.general.me_adaptive_floor_max_active_writers_per_core =
new.general.me_adaptive_floor_max_active_writers_per_core; new.general.me_adaptive_floor_max_active_writers_per_core;
cfg.general.me_adaptive_floor_max_warm_writers_per_core = cfg.general.me_adaptive_floor_max_warm_writers_per_core =
@ -550,7 +543,8 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.api.minimal_runtime_cache_ttl_ms || old.server.api.minimal_runtime_cache_ttl_ms
!= new.server.api.minimal_runtime_cache_ttl_ms != new.server.api.minimal_runtime_cache_ttl_ms
|| old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled || old.server.api.runtime_edge_enabled != new.server.api.runtime_edge_enabled
|| old.server.api.runtime_edge_cache_ttl_ms != new.server.api.runtime_edge_cache_ttl_ms || old.server.api.runtime_edge_cache_ttl_ms
!= new.server.api.runtime_edge_cache_ttl_ms
|| old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n || old.server.api.runtime_edge_top_n != new.server.api.runtime_edge_top_n
|| old.server.api.runtime_edge_events_capacity || old.server.api.runtime_edge_events_capacity
!= new.server.api.runtime_edge_events_capacity != new.server.api.runtime_edge_events_capacity
@ -589,8 +583,10 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening || old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening
|| old.censorship.mask_shape_bucket_floor_bytes || old.censorship.mask_shape_bucket_floor_bytes
!= new.censorship.mask_shape_bucket_floor_bytes != new.censorship.mask_shape_bucket_floor_bytes
|| old.censorship.mask_shape_bucket_cap_bytes != new.censorship.mask_shape_bucket_cap_bytes || old.censorship.mask_shape_bucket_cap_bytes
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_bucket_cap_bytes
|| old.censorship.mask_shape_above_cap_blur
!= new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes || old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes != new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_timing_normalization_enabled || old.censorship.mask_timing_normalization_enabled
@ -874,7 +870,8 @@ fn log_changes(
{ {
info!( info!(
"config reload: me_bind_stale: mode={:?} ttl={}s", "config reload: me_bind_stale: mode={:?} ttl={}s",
new_hot.me_bind_stale_mode, new_hot.me_bind_stale_ttl_secs new_hot.me_bind_stale_mode,
new_hot.me_bind_stale_ttl_secs
); );
} }
if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot if old_hot.me_secret_atomic_snapshot != new_hot.me_secret_atomic_snapshot
@ -954,7 +951,8 @@ fn log_changes(
if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy { if old_hot.me_socks_kdf_policy != new_hot.me_socks_kdf_policy {
info!( info!(
"config reload: me_socks_kdf_policy: {:?} → {:?}", "config reload: me_socks_kdf_policy: {:?} → {:?}",
old_hot.me_socks_kdf_policy, new_hot.me_socks_kdf_policy, old_hot.me_socks_kdf_policy,
new_hot.me_socks_kdf_policy,
); );
} }
@ -1008,7 +1006,8 @@ fn log_changes(
|| old_hot.me_route_backpressure_high_watermark_pct || old_hot.me_route_backpressure_high_watermark_pct
!= new_hot.me_route_backpressure_high_watermark_pct != new_hot.me_route_backpressure_high_watermark_pct
|| old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms || old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms
|| old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy || old_hot.me_health_interval_ms_unhealthy
!= new_hot.me_health_interval_ms_unhealthy
|| old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy || old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy
|| old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms || old_hot.me_admission_poll_ms != new_hot.me_admission_poll_ms
|| old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms || old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms
@ -1045,27 +1044,19 @@ fn log_changes(
} }
if old_hot.users != new_hot.users { if old_hot.users != new_hot.users {
let mut added: Vec<&String> = new_hot let mut added: Vec<&String> = new_hot.users.keys()
.users
.keys()
.filter(|u| !old_hot.users.contains_key(*u)) .filter(|u| !old_hot.users.contains_key(*u))
.collect(); .collect();
added.sort(); added.sort();
let mut removed: Vec<&String> = old_hot let mut removed: Vec<&String> = old_hot.users.keys()
.users
.keys()
.filter(|u| !new_hot.users.contains_key(*u)) .filter(|u| !new_hot.users.contains_key(*u))
.collect(); .collect();
removed.sort(); removed.sort();
let mut changed: Vec<&String> = new_hot let mut changed: Vec<&String> = new_hot.users.keys()
.users
.keys()
.filter(|u| { .filter(|u| {
old_hot old_hot.users.get(*u)
.users
.get(*u)
.map(|s| s != &new_hot.users[*u]) .map(|s| s != &new_hot.users[*u])
.unwrap_or(false) .unwrap_or(false)
}) })
@ -1075,18 +1066,10 @@ fn log_changes(
if !added.is_empty() { if !added.is_empty() {
info!( info!(
"config reload: users added: [{}]", "config reload: users added: [{}]",
added added.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6); let host = resolve_link_host(new_cfg, detected_ip_v4, detected_ip_v6);
let port = new_cfg let port = new_cfg.general.links.public_port.unwrap_or(new_cfg.server.port);
.general
.links
.public_port
.unwrap_or(new_cfg.server.port);
for user in &added { for user in &added {
if let Some(secret) = new_hot.users.get(*user) { if let Some(secret) = new_hot.users.get(*user) {
print_user_links(user, secret, &host, port, new_cfg); print_user_links(user, secret, &host, port, new_cfg);
@ -1096,21 +1079,13 @@ fn log_changes(
if !removed.is_empty() { if !removed.is_empty() {
info!( info!(
"config reload: users removed: [{}]", "config reload: users removed: [{}]",
removed removed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
} }
if !changed.is_empty() { if !changed.is_empty() {
info!( info!(
"config reload: users secret changed: [{}]", "config reload: users secret changed: [{}]",
changed changed.iter().map(|s| s.as_str()).collect::<Vec<_>>().join(", ")
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(", ")
); );
} }
} }
@ -1141,7 +1116,8 @@ fn log_changes(
} }
if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each if old_hot.user_max_unique_ips_global_each != new_hot.user_max_unique_ips_global_each
|| old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode || old_hot.user_max_unique_ips_mode != new_hot.user_max_unique_ips_mode
|| old_hot.user_max_unique_ips_window_secs != new_hot.user_max_unique_ips_window_secs || old_hot.user_max_unique_ips_window_secs
!= new_hot.user_max_unique_ips_window_secs
{ {
info!( info!(
"config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s", "config reload: user_max_unique_ips policy global_each={} mode={:?} window={}s",
@ -1176,10 +1152,7 @@ fn reload_config(
let next_manifest = WatchManifest::from_source_files(&source_files); let next_manifest = WatchManifest::from_source_files(&source_files);
if let Err(e) = new_cfg.validate() { if let Err(e) = new_cfg.validate() {
error!( error!("config reload: validation failed: {}; keeping old config", e);
"config reload: validation failed: {}; keeping old config",
e
);
return Some(next_manifest); return Some(next_manifest);
} }
@ -1244,7 +1217,7 @@ pub fn spawn_config_watcher(
) -> (watch::Receiver<Arc<ProxyConfig>>, watch::Receiver<LogLevel>) { ) -> (watch::Receiver<Arc<ProxyConfig>>, watch::Receiver<LogLevel>) {
let initial_level = initial.general.log_level.clone(); let initial_level = initial.general.log_level.clone();
let (config_tx, config_rx) = watch::channel(initial); let (config_tx, config_rx) = watch::channel(initial);
let (log_tx, log_rx) = watch::channel(initial_level); let (log_tx, log_rx) = watch::channel(initial_level);
let config_path = normalize_watch_path(&config_path); let config_path = normalize_watch_path(&config_path);
let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok(); let initial_loaded = ProxyConfig::load_with_metadata(&config_path).ok();
@ -1261,29 +1234,25 @@ pub fn spawn_config_watcher(
let tx_inotify = notify_tx.clone(); let tx_inotify = notify_tx.clone();
let manifest_for_inotify = manifest_state.clone(); let manifest_for_inotify = manifest_state.clone();
let mut inotify_watcher = let mut inotify_watcher = match recommended_watcher(move |res: notify::Result<notify::Event>| {
match recommended_watcher(move |res: notify::Result<notify::Event>| { let Ok(event) = res else { return };
let Ok(event) = res else { return }; if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
if !matches!( return;
event.kind, }
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) let is_our_file = manifest_for_inotify
) { .read()
return; .map(|manifest| manifest.matches_event_paths(&event.paths))
} .unwrap_or(false);
let is_our_file = manifest_for_inotify if is_our_file {
.read() let _ = tx_inotify.try_send(());
.map(|manifest| manifest.matches_event_paths(&event.paths)) }
.unwrap_or(false); }) {
if is_our_file { Ok(watcher) => Some(watcher),
let _ = tx_inotify.try_send(()); Err(e) => {
} warn!("config watcher: inotify unavailable: {}", e);
}) { None
Ok(watcher) => Some(watcher), }
Err(e) => { };
warn!("config watcher: inotify unavailable: {}", e);
None
}
};
apply_watch_manifest( apply_watch_manifest(
inotify_watcher.as_mut(), inotify_watcher.as_mut(),
Option::<&mut notify::poll::PollWatcher>::None, Option::<&mut notify::poll::PollWatcher>::None,
@ -1299,10 +1268,7 @@ pub fn spawn_config_watcher(
let mut poll_watcher = match notify::poll::PollWatcher::new( let mut poll_watcher = match notify::poll::PollWatcher::new(
move |res: notify::Result<notify::Event>| { move |res: notify::Result<notify::Event>| {
let Ok(event) = res else { return }; let Ok(event) = res else { return };
if !matches!( if !matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)) {
event.kind,
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
) {
return; return;
} }
let is_our_file = manifest_for_poll let is_our_file = manifest_for_poll
@ -1350,9 +1316,7 @@ pub fn spawn_config_watcher(
} }
} }
#[cfg(not(unix))] #[cfg(not(unix))]
if notify_rx.recv().await.is_none() { if notify_rx.recv().await.is_none() { break; }
break;
}
// Debounce: drain extra events that arrive within a short quiet window. // Debounce: drain extra events that arrive within a short quiet window.
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
@ -1454,10 +1418,7 @@ mod tests {
new.server.port = old.server.port.saturating_add(1); new.server.port = old.server.port.saturating_add(1);
let applied = overlay_hot_fields(&old, &new); let applied = overlay_hot_fields(&old, &new);
assert_eq!( assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied));
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
assert_eq!(applied.server.port, old.server.port); assert_eq!(applied.server.port, old.server.port);
} }
@ -1476,10 +1437,7 @@ mod tests {
applied.general.me_bind_stale_mode, applied.general.me_bind_stale_mode,
new.general.me_bind_stale_mode new.general.me_bind_stale_mode
); );
assert_ne!( assert_ne!(HotFields::from_config(&old), HotFields::from_config(&applied));
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
} }
#[test] #[test]
@ -1493,10 +1451,7 @@ mod tests {
applied.general.me_keepalive_interval_secs, applied.general.me_keepalive_interval_secs,
old.general.me_keepalive_interval_secs old.general.me_keepalive_interval_secs
); );
assert_eq!( assert_eq!(HotFields::from_config(&old), HotFields::from_config(&applied));
HotFields::from_config(&old),
HotFields::from_config(&applied)
);
} }
#[test] #[test]
@ -1508,10 +1463,7 @@ mod tests {
let applied = overlay_hot_fields(&old, &new); let applied = overlay_hot_fields(&old, &new);
assert_eq!(applied.general.hardswap, new.general.hardswap); assert_eq!(applied.general.hardswap, new.general.hardswap);
assert_eq!( assert_eq!(applied.general.use_middle_proxy, old.general.use_middle_proxy);
applied.general.use_middle_proxy,
old.general.use_middle_proxy
);
assert!(!config_equal(&applied, &new)); assert!(!config_equal(&applied, &new));
} }
@ -1523,19 +1475,14 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path) let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, Some(final_tag), None); write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!( assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }
@ -1548,9 +1495,7 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path) let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
@ -1573,9 +1518,7 @@ mod tests {
write_reload_config(&path, Some(initial_tag), None); write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path) let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
.unwrap()
.rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash)); let mut reload_state = ReloadState::new(Some(initial_hash));
@ -1589,10 +1532,7 @@ mod tests {
write_reload_config(&path, Some(final_tag), None); write_reload_config(&path, Some(final_tag), None);
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
assert_eq!( assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
config_tx.borrow().general.ad_tag.as_deref(),
Some(final_tag)
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }

View File

@ -399,18 +399,11 @@ impl ProxyConfig {
)); ));
} }
if config.censorship.mask_shape_above_cap_blur && !config.censorship.mask_shape_hardening { if config.censorship.mask_shape_above_cap_blur
return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
.to_string(),
));
}
if config.censorship.mask_shape_hardening_aggressive_mode
&& !config.censorship.mask_shape_hardening && !config.censorship.mask_shape_hardening
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true" "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
.to_string(), .to_string(),
)); ));
} }
@ -426,7 +419,8 @@ impl ProxyConfig {
if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 { if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576".to_string(), "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576"
.to_string(),
)); ));
} }
@ -450,7 +444,8 @@ impl ProxyConfig {
if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 { if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"censorship.mask_timing_normalization_ceiling_ms must be <= 60000".to_string(), "censorship.mask_timing_normalization_ceiling_ms must be <= 60000"
.to_string(),
)); ));
} }
@ -466,7 +461,8 @@ impl ProxyConfig {
)); ));
} }
if config.timeouts.relay_client_idle_hard_secs < config.timeouts.relay_client_idle_soft_secs if config.timeouts.relay_client_idle_hard_secs
< config.timeouts.relay_client_idle_soft_secs
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
@ -474,9 +470,7 @@ impl ProxyConfig {
)); ));
} }
if config if config.timeouts.relay_idle_grace_after_downstream_activity_secs
.timeouts
.relay_idle_grace_after_downstream_activity_secs
> config.timeouts.relay_client_idle_hard_secs > config.timeouts.relay_client_idle_hard_secs
{ {
return Err(ProxyError::Config( return Err(ProxyError::Config(
@ -773,8 +767,7 @@ impl ProxyConfig {
} }
if config.general.me_route_backpressure_base_timeout_ms > 5000 { if config.general.me_route_backpressure_base_timeout_ms > 5000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"general.me_route_backpressure_base_timeout_ms must be within [1, 5000]" "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(),
.to_string(),
)); ));
} }
@ -787,8 +780,7 @@ impl ProxyConfig {
} }
if config.general.me_route_backpressure_high_timeout_ms > 5000 { if config.general.me_route_backpressure_high_timeout_ms > 5000 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"general.me_route_backpressure_high_timeout_ms must be within [1, 5000]" "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(),
.to_string(),
)); ));
} }
@ -1836,9 +1828,7 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml"); let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap(); std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string(); let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!( assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]"));
err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }
@ -1859,9 +1849,7 @@ mod tests {
let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml"); let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap(); std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string(); let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!( assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]"));
err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")
);
let _ = std::fs::remove_file(path); let _ = std::fs::remove_file(path);
} }

View File

@ -1,9 +1,9 @@
//! Configuration. //! Configuration.
pub(crate) mod defaults; pub(crate) mod defaults;
pub mod hot_reload;
mod load;
mod types; mod types;
mod load;
pub mod hot_reload;
pub use load::ProxyConfig; pub use load::ProxyConfig;
pub use types::*; pub use types::*;

View File

@ -30,9 +30,7 @@ relay_client_idle_hard_secs = 60
let err = ProxyConfig::load(&path).expect_err("config with hard<soft must fail"); let err = ProxyConfig::load(&path).expect_err("config with hard<soft must fail");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains( msg.contains("timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"),
"timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs"
),
"error must explain the violated hard>=soft invariant, got: {msg}" "error must explain the violated hard>=soft invariant, got: {msg}"
); );

View File

@ -91,13 +91,11 @@ mask_shape_above_cap_blur_max_bytes = 64
"#, "#,
); );
let err = let err = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect_err("above-cap blur must require shape hardening enabled"); .expect_err("above-cap blur must require shape hardening enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains( msg.contains("censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"),
"censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"
),
"error must explain blur prerequisite, got: {msg}" "error must explain blur prerequisite, got: {msg}"
); );
@ -115,8 +113,8 @@ mask_shape_above_cap_blur_max_bytes = 0
"#, "#,
); );
let err = let err = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect_err("above-cap blur max bytes must be > 0 when enabled"); .expect_err("above-cap blur max bytes must be > 0 when enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"), msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"),
@ -137,8 +135,8 @@ mask_timing_normalization_ceiling_ms = 200
"#, "#,
); );
let err = let err = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect_err("timing normalization floor must be > 0 when enabled"); .expect_err("timing normalization floor must be > 0 when enabled");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"), msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"),
@ -159,7 +157,8 @@ mask_timing_normalization_ceiling_ms = 200
"#, "#,
); );
let err = ProxyConfig::load(&path).expect_err("timing normalization ceiling must be >= floor"); let err = ProxyConfig::load(&path)
.expect_err("timing normalization ceiling must be >= floor");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"), msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"),
@ -194,45 +193,3 @@ mask_timing_normalization_ceiling_ms = 240
remove_temp_config(&path); remove_temp_config(&path);
} }
#[test]
fn load_rejects_aggressive_shape_mode_when_shape_hardening_disabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = false
mask_shape_hardening_aggressive_mode = true
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("aggressive shape hardening mode must require shape hardening enabled");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_shape_hardening_aggressive_mode requires censorship.mask_shape_hardening = true"),
"error must explain aggressive-mode prerequisite, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_aggressive_shape_mode_when_shape_hardening_enabled() {
let path = write_temp_config(
r#"
[censorship]
mask_shape_hardening = true
mask_shape_hardening_aggressive_mode = true
mask_shape_above_cap_blur = true
mask_shape_above_cap_blur_max_bytes = 8
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("aggressive shape hardening mode should be accepted when prerequisites are met");
assert!(cfg.censorship.mask_shape_hardening);
assert!(cfg.censorship.mask_shape_hardening_aggressive_mode);
assert!(cfg.censorship.mask_shape_above_cap_blur);
remove_temp_config(&path);
}

View File

@ -29,13 +29,11 @@ server_hello_delay_max_ms = 1000
"#, "#,
); );
let err = let err = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect_err("delay equal to handshake timeout must be rejected"); .expect_err("delay equal to handshake timeout must be rejected");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains( msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"),
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}" "error must explain delay<timeout invariant, got: {msg}"
); );
@ -54,13 +52,11 @@ server_hello_delay_max_ms = 1500
"#, "#,
); );
let err = let err = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect_err("delay larger than handshake timeout must be rejected"); .expect_err("delay larger than handshake timeout must be rejected");
let msg = err.to_string(); let msg = err.to_string();
assert!( assert!(
msg.contains( msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"),
"censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"
),
"error must explain delay<timeout invariant, got: {msg}" "error must explain delay<timeout invariant, got: {msg}"
); );
@ -79,8 +75,8 @@ server_hello_delay_max_ms = 999
"#, "#,
); );
let cfg = let cfg = ProxyConfig::load(&path)
ProxyConfig::load(&path).expect("delay below handshake timeout budget must be accepted"); .expect("delay below handshake timeout budget must be accepted");
assert_eq!(cfg.timeouts.client_handshake, 1); assert_eq!(cfg.timeouts.client_handshake, 1);
assert_eq!(cfg.censorship.server_hello_delay_max_ms, 999); assert_eq!(cfg.censorship.server_hello_delay_max_ms, 999);

View File

@ -1047,7 +1047,8 @@ impl Default for GeneralConfig {
me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(), me_pool_drain_soft_evict_per_writer: default_me_pool_drain_soft_evict_per_writer(),
me_pool_drain_soft_evict_budget_per_core: me_pool_drain_soft_evict_budget_per_core:
default_me_pool_drain_soft_evict_budget_per_core(), default_me_pool_drain_soft_evict_budget_per_core(),
me_pool_drain_soft_evict_cooldown_ms: default_me_pool_drain_soft_evict_cooldown_ms(), me_pool_drain_soft_evict_cooldown_ms:
default_me_pool_drain_soft_evict_cooldown_ms(),
me_bind_stale_mode: MeBindStaleMode::default(), me_bind_stale_mode: MeBindStaleMode::default(),
me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(), me_bind_stale_ttl_secs: default_me_bind_stale_ttl_secs(),
me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(), me_pool_min_fresh_ratio: default_me_pool_min_fresh_ratio(),
@ -1417,12 +1418,6 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_hardening")] #[serde(default = "default_mask_shape_hardening")]
pub mask_shape_hardening: bool, pub mask_shape_hardening: bool,
/// Opt-in aggressive shape hardening mode.
/// When enabled, masking may shape some backend-silent timeout paths and
/// enforces strictly positive above-cap blur when blur is enabled.
#[serde(default = "default_mask_shape_hardening_aggressive_mode")]
pub mask_shape_hardening_aggressive_mode: bool,
/// Minimum bucket size for mask shape hardening padding. /// Minimum bucket size for mask shape hardening padding.
#[serde(default = "default_mask_shape_bucket_floor_bytes")] #[serde(default = "default_mask_shape_bucket_floor_bytes")]
pub mask_shape_bucket_floor_bytes: usize, pub mask_shape_bucket_floor_bytes: usize,
@ -1473,7 +1468,6 @@ impl Default for AntiCensorshipConfig {
alpn_enforce: default_alpn_enforce(), alpn_enforce: default_alpn_enforce(),
mask_proxy_protocol: 0, mask_proxy_protocol: 0,
mask_shape_hardening: default_mask_shape_hardening(), mask_shape_hardening: default_mask_shape_hardening(),
mask_shape_hardening_aggressive_mode: default_mask_shape_hardening_aggressive_mode(),
mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(), mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(),
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),

View File

@ -13,13 +13,10 @@
#![allow(dead_code)] #![allow(dead_code)]
use crate::error::{ProxyError, Result};
use aes::Aes256; use aes::Aes256;
use ctr::{ use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
Ctr128BE,
cipher::{KeyIvInit, StreamCipher},
};
use zeroize::Zeroize; use zeroize::Zeroize;
use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>; type Aes256Ctr = Ctr128BE<Aes256>;
@ -45,39 +42,33 @@ impl AesCtr {
cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()), cipher: Aes256Ctr::new(key.into(), (&iv_bytes).into()),
} }
} }
/// Create from key and IV slices /// Create from key and IV slices
pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_key_iv(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
expected: 32,
got: key.len(),
});
} }
if iv.len() != 16 { if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
expected: 16,
got: iv.len(),
});
} }
let key: [u8; 32] = key.try_into().unwrap(); let key: [u8; 32] = key.try_into().unwrap();
let iv = u128::from_be_bytes(iv.try_into().unwrap()); let iv = u128::from_be_bytes(iv.try_into().unwrap());
Ok(Self::new(&key, iv)) Ok(Self::new(&key, iv))
} }
/// Encrypt/decrypt data in-place (CTR mode is symmetric) /// Encrypt/decrypt data in-place (CTR mode is symmetric)
pub fn apply(&mut self, data: &mut [u8]) { pub fn apply(&mut self, data: &mut [u8]) {
self.cipher.apply_keystream(data); self.cipher.apply_keystream(data);
} }
/// Encrypt data, returning new buffer /// Encrypt data, returning new buffer
pub fn encrypt(&mut self, data: &[u8]) -> Vec<u8> { pub fn encrypt(&mut self, data: &[u8]) -> Vec<u8> {
let mut output = data.to_vec(); let mut output = data.to_vec();
self.apply(&mut output); self.apply(&mut output);
output output
} }
/// Decrypt data (for CTR, identical to encrypt) /// Decrypt data (for CTR, identical to encrypt)
pub fn decrypt(&mut self, data: &[u8]) -> Vec<u8> { pub fn decrypt(&mut self, data: &[u8]) -> Vec<u8> {
self.encrypt(data) self.encrypt(data)
@ -108,33 +99,27 @@ impl Drop for AesCbc {
impl AesCbc { impl AesCbc {
/// AES block size /// AES block size
const BLOCK_SIZE: usize = 16; const BLOCK_SIZE: usize = 16;
/// Create new AES-CBC cipher with key and IV /// Create new AES-CBC cipher with key and IV
pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self { pub fn new(key: [u8; 32], iv: [u8; 16]) -> Self {
Self { key, iv } Self { key, iv }
} }
/// Create from slices /// Create from slices
pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> { pub fn from_slices(key: &[u8], iv: &[u8]) -> Result<Self> {
if key.len() != 32 { if key.len() != 32 {
return Err(ProxyError::InvalidKeyLength { return Err(ProxyError::InvalidKeyLength { expected: 32, got: key.len() });
expected: 32,
got: key.len(),
});
} }
if iv.len() != 16 { if iv.len() != 16 {
return Err(ProxyError::InvalidKeyLength { return Err(ProxyError::InvalidKeyLength { expected: 16, got: iv.len() });
expected: 16,
got: iv.len(),
});
} }
Ok(Self { Ok(Self {
key: key.try_into().unwrap(), key: key.try_into().unwrap(),
iv: iv.try_into().unwrap(), iv: iv.try_into().unwrap(),
}) })
} }
/// Encrypt a single block using raw AES (no chaining) /// Encrypt a single block using raw AES (no chaining)
fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { fn encrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
use aes::cipher::BlockEncrypt; use aes::cipher::BlockEncrypt;
@ -142,7 +127,7 @@ impl AesCbc {
key_schedule.encrypt_block((&mut output).into()); key_schedule.encrypt_block((&mut output).into());
output output
} }
/// Decrypt a single block using raw AES (no chaining) /// Decrypt a single block using raw AES (no chaining)
fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] { fn decrypt_block(&self, block: &[u8; 16], key_schedule: &aes::Aes256) -> [u8; 16] {
use aes::cipher::BlockDecrypt; use aes::cipher::BlockDecrypt;
@ -150,7 +135,7 @@ impl AesCbc {
key_schedule.decrypt_block((&mut output).into()); key_schedule.decrypt_block((&mut output).into());
output output
} }
/// XOR two 16-byte blocks /// XOR two 16-byte blocks
fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] { fn xor_blocks(a: &[u8; 16], b: &[u8; 16]) -> [u8; 16] {
let mut result = [0u8; 16]; let mut result = [0u8; 16];
@ -159,28 +144,27 @@ impl AesCbc {
} }
result result
} }
/// Encrypt data using CBC mode with proper chaining /// Encrypt data using CBC mode with proper chaining
/// ///
/// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV /// CBC Encryption: C[i] = AES_Encrypt(P[i] XOR C[i-1]), where C[-1] = IV
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> { pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(format!( return Err(ProxyError::Crypto(
"CBC data must be aligned to 16 bytes, got {}", format!("CBC data must be aligned to 16 bytes, got {}", data.len())
data.len() ));
)));
} }
if data.is_empty() { if data.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
let mut result = Vec::with_capacity(data.len()); let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let plaintext: [u8; 16] = chunk.try_into().unwrap(); let plaintext: [u8; 16] = chunk.try_into().unwrap();
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
@ -188,31 +172,30 @@ impl AesCbc {
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
result.extend_from_slice(&ciphertext); result.extend_from_slice(&ciphertext);
} }
Ok(result) Ok(result)
} }
/// Decrypt data using CBC mode with proper chaining /// Decrypt data using CBC mode with proper chaining
/// ///
/// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV /// CBC Decryption: P[i] = AES_Decrypt(C[i]) XOR C[i-1], where C[-1] = IV
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> { pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(format!( return Err(ProxyError::Crypto(
"CBC data must be aligned to 16 bytes, got {}", format!("CBC data must be aligned to 16 bytes, got {}", data.len())
data.len() ));
)));
} }
if data.is_empty() { if data.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
let mut result = Vec::with_capacity(data.len()); let mut result = Vec::with_capacity(data.len());
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap(); let ciphertext: [u8; 16] = chunk.try_into().unwrap();
let decrypted = self.decrypt_block(&ciphertext, &key_schedule); let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
@ -220,77 +203,75 @@ impl AesCbc {
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
result.extend_from_slice(&plaintext); result.extend_from_slice(&plaintext);
} }
Ok(result) Ok(result)
} }
/// Encrypt data in-place /// Encrypt data in-place
pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn encrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(format!( return Err(ProxyError::Crypto(
"CBC data must be aligned to 16 bytes, got {}", format!("CBC data must be aligned to 16 bytes, got {}", data.len())
data.len() ));
)));
} }
if data.is_empty() { if data.is_empty() {
return Ok(()); return Ok(());
} }
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule); *block_array = self.encrypt_block(block_array, &key_schedule);
prev_ciphertext = *block_array; prev_ciphertext = *block_array;
} }
Ok(()) Ok(())
} }
/// Decrypt data in-place /// Decrypt data in-place
pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> { pub fn decrypt_in_place(&self, data: &mut [u8]) -> Result<()> {
if !data.len().is_multiple_of(Self::BLOCK_SIZE) { if !data.len().is_multiple_of(Self::BLOCK_SIZE) {
return Err(ProxyError::Crypto(format!( return Err(ProxyError::Crypto(
"CBC data must be aligned to 16 bytes, got {}", format!("CBC data must be aligned to 16 bytes, got {}", data.len())
data.len() ));
)));
} }
if data.is_empty() { if data.is_empty() {
return Ok(()); return Ok(());
} }
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
let current_ciphertext: [u8; 16] = block.try_into().unwrap(); let current_ciphertext: [u8; 16] = block.try_into().unwrap();
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule); *block_array = self.decrypt_block(block_array, &key_schedule);
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
prev_ciphertext = current_ciphertext; prev_ciphertext = current_ciphertext;
} }
Ok(()) Ok(())
} }
} }
@ -337,227 +318,227 @@ impl Decryptor for PassthroughEncryptor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// ============= AES-CTR Tests ============= // ============= AES-CTR Tests =============
#[test] #[test]
fn test_aes_ctr_roundtrip() { fn test_aes_ctr_roundtrip() {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = 12345u128; let iv = 12345u128;
let original = b"Hello, MTProto!"; let original = b"Hello, MTProto!";
let mut enc = AesCtr::new(&key, iv); let mut enc = AesCtr::new(&key, iv);
let encrypted = enc.encrypt(original); let encrypted = enc.encrypt(original);
let mut dec = AesCtr::new(&key, iv); let mut dec = AesCtr::new(&key, iv);
let decrypted = dec.decrypt(&encrypted); let decrypted = dec.decrypt(&encrypted);
assert_eq!(original.as_slice(), decrypted.as_slice()); assert_eq!(original.as_slice(), decrypted.as_slice());
} }
#[test] #[test]
fn test_aes_ctr_in_place() { fn test_aes_ctr_in_place() {
let key = [0x42u8; 32]; let key = [0x42u8; 32];
let iv = 999u128; let iv = 999u128;
let original = b"Test data for in-place encryption"; let original = b"Test data for in-place encryption";
let mut data = original.to_vec(); let mut data = original.to_vec();
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
assert_ne!(&data[..], original); assert_ne!(&data[..], original);
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
assert_eq!(&data[..], original); assert_eq!(&data[..], original);
} }
// ============= AES-CBC Tests ============= // ============= AES-CBC Tests =============
#[test] #[test]
fn test_aes_cbc_roundtrip() { fn test_aes_cbc_roundtrip() {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let original = [0u8; 32]; let original = [0u8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap(); let encrypted = cipher.encrypt(&original).unwrap();
let decrypted = cipher.decrypt(&encrypted).unwrap(); let decrypted = cipher.decrypt(&encrypted).unwrap();
assert_eq!(original.as_slice(), decrypted.as_slice()); assert_eq!(original.as_slice(), decrypted.as_slice());
} }
#[test] #[test]
fn test_aes_cbc_chaining_works() { fn test_aes_cbc_chaining_works() {
let key = [0x42u8; 32]; let key = [0x42u8; 32];
let iv = [0x00u8; 16]; let iv = [0x00u8; 16];
let plaintext = [0xAAu8; 32]; let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
let block1 = &ciphertext[0..16]; let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32]; let block2 = &ciphertext[16..32];
assert_ne!( assert_ne!(
block1, block2, block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext" "CBC chaining broken: identical plaintext blocks produced identical ciphertext"
); );
} }
#[test] #[test]
fn test_aes_cbc_known_vector() { fn test_aes_cbc_known_vector() {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let plaintext = [0u8; 16]; let plaintext = [0u8; 16];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
let decrypted = cipher.decrypt(&ciphertext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice()); assert_eq!(plaintext.as_slice(), decrypted.as_slice());
assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
} }
#[test] #[test]
fn test_aes_cbc_multi_block() { fn test_aes_cbc_multi_block() {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
let plaintext: Vec<u8> = (0..80).collect(); let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
let decrypted = cipher.decrypt(&ciphertext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext, decrypted); assert_eq!(plaintext, decrypted);
} }
#[test] #[test]
fn test_aes_cbc_in_place() { fn test_aes_cbc_in_place() {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
let original = [0x56u8; 48]; let original = [0x56u8; 48];
let mut buffer = original; let mut buffer = original;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
cipher.encrypt_in_place(&mut buffer).unwrap(); cipher.encrypt_in_place(&mut buffer).unwrap();
assert_ne!(&buffer[..], &original[..]); assert_ne!(&buffer[..], &original[..]);
cipher.decrypt_in_place(&mut buffer).unwrap(); cipher.decrypt_in_place(&mut buffer).unwrap();
assert_eq!(&buffer[..], &original[..]); assert_eq!(&buffer[..], &original[..]);
} }
#[test] #[test]
fn test_aes_cbc_empty_data() { fn test_aes_cbc_empty_data() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]); let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
let encrypted = cipher.encrypt(&[]).unwrap(); let encrypted = cipher.encrypt(&[]).unwrap();
assert!(encrypted.is_empty()); assert!(encrypted.is_empty());
let decrypted = cipher.decrypt(&[]).unwrap(); let decrypted = cipher.decrypt(&[]).unwrap();
assert!(decrypted.is_empty()); assert!(decrypted.is_empty());
} }
#[test] #[test]
fn test_aes_cbc_unaligned_error() { fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]); let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
let result = cipher.encrypt(&[0u8; 15]); let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err()); assert!(result.is_err());
let result = cipher.encrypt(&[0u8; 17]); let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn test_aes_cbc_avalanche_effect() { fn test_aes_cbc_avalanche_effect() {
let key = [0xAB; 32]; let key = [0xAB; 32];
let iv = [0xCD; 16]; let iv = [0xCD; 16];
let plaintext1 = [0u8; 32]; let plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32]; let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; plaintext2[0] = 0x01;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
} }
#[test] #[test]
fn test_aes_cbc_iv_matters() { fn test_aes_cbc_iv_matters() {
let key = [0x55; 32]; let key = [0x55; 32];
let plaintext = [0x77u8; 16]; let plaintext = [0x77u8; 16];
let cipher1 = AesCbc::new(key, [0u8; 16]); let cipher1 = AesCbc::new(key, [0u8; 16]);
let cipher2 = AesCbc::new(key, [1u8; 16]); let cipher2 = AesCbc::new(key, [1u8; 16]);
let ciphertext1 = cipher1.encrypt(&plaintext).unwrap(); let ciphertext1 = cipher1.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher2.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher2.encrypt(&plaintext).unwrap();
assert_ne!(ciphertext1, ciphertext2); assert_ne!(ciphertext1, ciphertext2);
} }
#[test] #[test]
fn test_aes_cbc_deterministic() { fn test_aes_cbc_deterministic() {
let key = [0x99; 32]; let key = [0x99; 32];
let iv = [0x88; 16]; let iv = [0x88; 16];
let plaintext = [0x77u8; 32]; let plaintext = [0x77u8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext).unwrap(); let ciphertext1 = cipher.encrypt(&plaintext).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext).unwrap();
assert_eq!(ciphertext1, ciphertext2); assert_eq!(ciphertext1, ciphertext2);
} }
// ============= Zeroize Tests ============= // ============= Zeroize Tests =============
#[test] #[test]
fn test_aes_cbc_zeroize_on_drop() { fn test_aes_cbc_zeroize_on_drop() {
let key = [0xAA; 32]; let key = [0xAA; 32];
let iv = [0xBB; 16]; let iv = [0xBB; 16];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
// Verify key/iv are set // Verify key/iv are set
assert_eq!(cipher.key, [0xAA; 32]); assert_eq!(cipher.key, [0xAA; 32]);
assert_eq!(cipher.iv, [0xBB; 16]); assert_eq!(cipher.iv, [0xBB; 16]);
drop(cipher); drop(cipher);
// After drop, key/iv are zeroized (can't observe directly, // After drop, key/iv are zeroized (can't observe directly,
// but the Drop impl runs without panic) // but the Drop impl runs without panic)
} }
// ============= Error Handling Tests ============= // ============= Error Handling Tests =============
#[test] #[test]
fn test_invalid_key_length() { fn test_invalid_key_length() {
let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]); let result = AesCtr::from_key_iv(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err()); assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]); let result = AesCbc::from_slices(&[0u8; 16], &[0u8; 16]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn test_invalid_iv_length() { fn test_invalid_iv_length() {
let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]); let result = AesCtr::from_key_iv(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err()); assert!(result.is_err());
let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]); let result = AesCbc::from_slices(&[0u8; 32], &[0u8; 8]);
assert!(result.is_err()); assert!(result.is_err());
} }
} }

View File

@ -12,10 +12,10 @@
//! usages are intentional and protocol-mandated. //! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256;
use md5::Md5; use md5::Md5;
use sha1::Sha1; use sha1::Sha1;
use sha2::Digest; use sha2::Digest;
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>; type HmacSha256 = Hmac<Sha256>;
@ -28,7 +28,8 @@ pub fn sha256(data: &[u8]) -> [u8; 32] {
/// SHA-256 HMAC /// SHA-256 HMAC
pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length"); let mut mac = HmacSha256::new_from_slice(key)
.expect("HMAC accepts any key length");
mac.update(data); mac.update(data);
mac.finalize().into_bytes().into() mac.finalize().into_bytes().into()
} }
@ -123,18 +124,27 @@ pub fn derive_middleproxy_keys(
srv_ipv6: Option<&[u8; 16]>, srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) { ) -> ([u8; 32], [u8; 16]) {
let s = build_middleproxy_prekey( let s = build_middleproxy_prekey(
nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, clt_ip, srv_port, secret, nonce_srv,
clt_ipv6, srv_ipv6, nonce_clt,
clt_ts,
srv_ip,
clt_port,
purpose,
clt_ip,
srv_port,
secret,
clt_ipv6,
srv_ipv6,
); );
let md5_1 = md5(&s[1..]); let md5_1 = md5(&s[1..]);
let sha1_sum = sha1(&s); let sha1_sum = sha1(&s);
let md5_2 = md5(&s[2..]); let md5_2 = md5(&s[2..]);
let mut key = [0u8; 32]; let mut key = [0u8; 32];
key[..12].copy_from_slice(&md5_1[..12]); key[..12].copy_from_slice(&md5_1[..12]);
key[12..].copy_from_slice(&sha1_sum); key[12..].copy_from_slice(&sha1_sum);
(key, md5_2) (key, md5_2)
} }
@ -154,8 +164,17 @@ mod tests {
let secret = vec![0x55u8; 128]; let secret = vec![0x55u8; 128];
let prekey = build_middleproxy_prekey( let prekey = build_middleproxy_prekey(
&nonce_srv, &nonce_clt, &clt_ts, srv_ip, &clt_port, b"CLIENT", clt_ip, &srv_port, &nonce_srv,
&secret, None, None, &nonce_clt,
&clt_ts,
srv_ip,
&clt_port,
b"CLIENT",
clt_ip,
&srv_port,
&secret,
None,
None,
); );
let digest = sha256(&prekey); let digest = sha256(&prekey);
assert_eq!( assert_eq!(

View File

@ -4,7 +4,7 @@ pub mod aes;
pub mod hash; pub mod hash;
pub mod random; pub mod random;
pub use aes::{AesCbc, AesCtr}; pub use aes::{AesCtr, AesCbc};
pub use hash::{ pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac, build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, sha256, sha256_hmac,
}; };

View File

@ -3,11 +3,11 @@
#![allow(deprecated)] #![allow(deprecated)]
#![allow(dead_code)] #![allow(dead_code)]
use crate::crypto::AesCtr;
use parking_lot::Mutex;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng}; use rand::{Rng, RngExt, SeedableRng};
use rand::rngs::StdRng;
use parking_lot::Mutex;
use zeroize::Zeroize; use zeroize::Zeroize;
use crate::crypto::AesCtr;
/// Cryptographically secure PRNG with AES-CTR /// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom { pub struct SecureRandom {
@ -34,16 +34,16 @@ impl SecureRandom {
pub fn new() -> Self { pub fn new() -> Self {
let mut seed_source = rand::rng(); let mut seed_source = rand::rng();
let mut rng = StdRng::from_rng(&mut seed_source); let mut rng = StdRng::from_rng(&mut seed_source);
let mut key = [0u8; 32]; let mut key = [0u8; 32];
rng.fill_bytes(&mut key); rng.fill_bytes(&mut key);
let iv: u128 = rng.random(); let iv: u128 = rng.random();
let cipher = AesCtr::new(&key, iv); let cipher = AesCtr::new(&key, iv);
// Zeroize local key copy — cipher already consumed it // Zeroize local key copy — cipher already consumed it
key.zeroize(); key.zeroize();
Self { Self {
inner: Mutex::new(SecureRandomInner { inner: Mutex::new(SecureRandomInner {
rng, rng,
@ -53,7 +53,7 @@ impl SecureRandom {
}), }),
} }
} }
/// Fill a caller-provided buffer with random bytes. /// Fill a caller-provided buffer with random bytes.
pub fn fill(&self, out: &mut [u8]) { pub fn fill(&self, out: &mut [u8]) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -94,7 +94,7 @@ impl SecureRandom {
self.fill(&mut out); self.fill(&mut out);
out out
} }
/// Generate random number in range [0, max) /// Generate random number in range [0, max)
pub fn range(&self, max: usize) -> usize { pub fn range(&self, max: usize) -> usize {
if max == 0 { if max == 0 {
@ -103,16 +103,16 @@ impl SecureRandom {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.random_range(0..max) inner.rng.random_range(0..max)
} }
/// Generate random bits /// Generate random bits
pub fn bits(&self, k: usize) -> u64 { pub fn bits(&self, k: usize) -> u64 {
if k == 0 { if k == 0 {
return 0; return 0;
} }
let bytes_needed = k.div_ceil(8); let bytes_needed = k.div_ceil(8);
let bytes = self.bytes(bytes_needed.min(8)); let bytes = self.bytes(bytes_needed.min(8));
let mut result = 0u64; let mut result = 0u64;
for (i, &b) in bytes.iter().enumerate() { for (i, &b) in bytes.iter().enumerate() {
if i >= 8 { if i >= 8 {
@ -120,14 +120,14 @@ impl SecureRandom {
} }
result |= (b as u64) << (i * 8); result |= (b as u64) << (i * 8);
} }
if k < 64 { if k < 64 {
result &= (1u64 << k) - 1; result &= (1u64 << k) - 1;
} }
result result
} }
/// Choose random element from slice /// Choose random element from slice
pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> { pub fn choose<'a, T>(&self, slice: &'a [T]) -> Option<&'a T> {
if slice.is_empty() { if slice.is_empty() {
@ -136,7 +136,7 @@ impl SecureRandom {
Some(&slice[self.range(slice.len())]) Some(&slice[self.range(slice.len())])
} }
} }
/// Shuffle slice in place /// Shuffle slice in place
pub fn shuffle<T>(&self, slice: &mut [T]) { pub fn shuffle<T>(&self, slice: &mut [T]) {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -145,13 +145,13 @@ impl SecureRandom {
slice.swap(i, j); slice.swap(i, j);
} }
} }
/// Generate random u32 /// Generate random u32
pub fn u32(&self) -> u32 { pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.random() inner.rng.random()
} }
/// Generate random u64 /// Generate random u64
pub fn u64(&self) -> u64 { pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
@ -169,7 +169,7 @@ impl Default for SecureRandom {
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashSet; use std::collections::HashSet;
#[test] #[test]
fn test_bytes_uniqueness() { fn test_bytes_uniqueness() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -177,7 +177,7 @@ mod tests {
let b = rng.bytes(32); let b = rng.bytes(32);
assert_ne!(a, b); assert_ne!(a, b);
} }
#[test] #[test]
fn test_bytes_length() { fn test_bytes_length() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -186,63 +186,63 @@ mod tests {
assert_eq!(rng.bytes(100).len(), 100); assert_eq!(rng.bytes(100).len(), 100);
assert_eq!(rng.bytes(1000).len(), 1000); assert_eq!(rng.bytes(1000).len(), 1000);
} }
#[test] #[test]
fn test_range() { fn test_range() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
for _ in 0..1000 { for _ in 0..1000 {
let n = rng.range(10); let n = rng.range(10);
assert!(n < 10); assert!(n < 10);
} }
assert_eq!(rng.range(1), 0); assert_eq!(rng.range(1), 0);
assert_eq!(rng.range(0), 0); assert_eq!(rng.range(0), 0);
} }
#[test] #[test]
fn test_bits() { fn test_bits() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(1) <= 1); assert!(rng.bits(1) <= 1);
} }
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(8) <= 255); assert!(rng.bits(8) <= 255);
} }
} }
#[test] #[test]
fn test_choose() { fn test_choose() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let items = vec![1, 2, 3, 4, 5]; let items = vec![1, 2, 3, 4, 5];
let mut seen = HashSet::new(); let mut seen = HashSet::new();
for _ in 0..1000 { for _ in 0..1000 {
if let Some(&item) = rng.choose(&items) { if let Some(&item) = rng.choose(&items) {
seen.insert(item); seen.insert(item);
} }
} }
assert_eq!(seen.len(), 5); assert_eq!(seen.len(), 5);
let empty: Vec<i32> = vec![]; let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none()); assert!(rng.choose(&empty).is_none());
} }
#[test] #[test]
fn test_shuffle() { fn test_shuffle() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut shuffled = original.clone(); let mut shuffled = original.clone();
rng.shuffle(&mut shuffled); rng.shuffle(&mut shuffled);
let mut sorted = shuffled.clone(); let mut sorted = shuffled.clone();
sorted.sort(); sorted.sort();
assert_eq!(sorted, original); assert_eq!(sorted, original);
assert_ne!(shuffled, original); assert_ne!(shuffled, original);
} }
} }

View File

@ -12,15 +12,28 @@ use thiserror::Error;
#[derive(Debug)] #[derive(Debug)]
pub enum StreamError { pub enum StreamError {
/// Partial read: got fewer bytes than expected /// Partial read: got fewer bytes than expected
PartialRead { expected: usize, got: usize }, PartialRead {
expected: usize,
got: usize,
},
/// Partial write: wrote fewer bytes than expected /// Partial write: wrote fewer bytes than expected
PartialWrite { expected: usize, written: usize }, PartialWrite {
expected: usize,
written: usize,
},
/// Stream is in poisoned state and cannot be used /// Stream is in poisoned state and cannot be used
Poisoned { reason: String }, Poisoned {
reason: String,
},
/// Buffer overflow: attempted to buffer more than allowed /// Buffer overflow: attempted to buffer more than allowed
BufferOverflow { limit: usize, attempted: usize }, BufferOverflow {
limit: usize,
attempted: usize,
},
/// Invalid frame format /// Invalid frame format
InvalidFrame { details: String }, InvalidFrame {
details: String,
},
/// Unexpected end of stream /// Unexpected end of stream
UnexpectedEof, UnexpectedEof,
/// Underlying I/O error /// Underlying I/O error
@ -34,21 +47,13 @@ impl fmt::Display for StreamError {
write!(f, "partial read: expected {} bytes, got {}", expected, got) write!(f, "partial read: expected {} bytes, got {}", expected, got)
} }
Self::PartialWrite { expected, written } => { Self::PartialWrite { expected, written } => {
write!( write!(f, "partial write: expected {} bytes, wrote {}", expected, written)
f,
"partial write: expected {} bytes, wrote {}",
expected, written
)
} }
Self::Poisoned { reason } => { Self::Poisoned { reason } => {
write!(f, "stream poisoned: {}", reason) write!(f, "stream poisoned: {}", reason)
} }
Self::BufferOverflow { limit, attempted } => { Self::BufferOverflow { limit, attempted } => {
write!( write!(f, "buffer overflow: limit {}, attempted {}", limit, attempted)
f,
"buffer overflow: limit {}, attempted {}",
limit, attempted
)
} }
Self::InvalidFrame { details } => { Self::InvalidFrame { details } => {
write!(f, "invalid frame: {}", details) write!(f, "invalid frame: {}", details)
@ -85,7 +90,9 @@ impl From<StreamError> for std::io::Error {
StreamError::UnexpectedEof => { StreamError::UnexpectedEof => {
std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err) std::io::Error::new(std::io::ErrorKind::UnexpectedEof, err)
} }
StreamError::Poisoned { .. } => std::io::Error::other(err), StreamError::Poisoned { .. } => {
std::io::Error::other(err)
}
StreamError::BufferOverflow { .. } => { StreamError::BufferOverflow { .. } => {
std::io::Error::new(std::io::ErrorKind::OutOfMemory, err) std::io::Error::new(std::io::ErrorKind::OutOfMemory, err)
} }
@ -105,7 +112,7 @@ impl From<StreamError> for std::io::Error {
pub trait Recoverable { pub trait Recoverable {
/// Check if error is recoverable (can retry operation) /// Check if error is recoverable (can retry operation)
fn is_recoverable(&self) -> bool; fn is_recoverable(&self) -> bool;
/// Check if connection can continue after this error /// Check if connection can continue after this error
fn can_continue(&self) -> bool; fn can_continue(&self) -> bool;
} }
@ -116,22 +123,19 @@ impl Recoverable for StreamError {
Self::PartialRead { .. } | Self::PartialWrite { .. } => true, Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
Self::Io(e) => matches!( Self::Io(e) => matches!(
e.kind(), e.kind(),
std::io::ErrorKind::WouldBlock std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted | std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut | std::io::ErrorKind::TimedOut
), ),
Self::Poisoned { .. } Self::Poisoned { .. }
| Self::BufferOverflow { .. } | Self::BufferOverflow { .. }
| Self::InvalidFrame { .. } | Self::InvalidFrame { .. }
| Self::UnexpectedEof => false, | Self::UnexpectedEof => false,
} }
} }
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
!matches!( !matches!(self, Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. })
self,
Self::Poisoned { .. } | Self::UnexpectedEof | Self::BufferOverflow { .. }
)
} }
} }
@ -139,19 +143,19 @@ impl Recoverable for std::io::Error {
fn is_recoverable(&self) -> bool { fn is_recoverable(&self) -> bool {
matches!( matches!(
self.kind(), self.kind(),
std::io::ErrorKind::WouldBlock std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted | std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut | std::io::ErrorKind::TimedOut
) )
} }
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
!matches!( !matches!(
self.kind(), self.kind(),
std::io::ErrorKind::BrokenPipe std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted | std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::NotConnected | std::io::ErrorKind::NotConnected
) )
} }
} }
@ -161,88 +165,96 @@ impl Recoverable for std::io::Error {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ProxyError { pub enum ProxyError {
// ============= Crypto Errors ============= // ============= Crypto Errors =============
#[error("Crypto error: {0}")] #[error("Crypto error: {0}")]
Crypto(String), Crypto(String),
#[error("Invalid key length: expected {expected}, got {got}")] #[error("Invalid key length: expected {expected}, got {got}")]
InvalidKeyLength { expected: usize, got: usize }, InvalidKeyLength { expected: usize, got: usize },
// ============= Stream Errors ============= // ============= Stream Errors =============
#[error("Stream error: {0}")] #[error("Stream error: {0}")]
Stream(#[from] StreamError), Stream(#[from] StreamError),
// ============= Protocol Errors ============= // ============= Protocol Errors =============
#[error("Invalid handshake: {0}")] #[error("Invalid handshake: {0}")]
InvalidHandshake(String), InvalidHandshake(String),
#[error("Invalid protocol tag: {0:02x?}")] #[error("Invalid protocol tag: {0:02x?}")]
InvalidProtoTag([u8; 4]), InvalidProtoTag([u8; 4]),
#[error("Invalid TLS record: type={record_type}, version={version:02x?}")] #[error("Invalid TLS record: type={record_type}, version={version:02x?}")]
InvalidTlsRecord { record_type: u8, version: [u8; 2] }, InvalidTlsRecord { record_type: u8, version: [u8; 2] },
#[error("Replay attack detected from {addr}")] #[error("Replay attack detected from {addr}")]
ReplayAttack { addr: SocketAddr }, ReplayAttack { addr: SocketAddr },
#[error("Time skew detected: client={client_time}, server={server_time}")] #[error("Time skew detected: client={client_time}, server={server_time}")]
TimeSkew { client_time: u32, server_time: u32 }, TimeSkew { client_time: u32, server_time: u32 },
#[error("Invalid message length: {len} (min={min}, max={max})")] #[error("Invalid message length: {len} (min={min}, max={max})")]
InvalidMessageLength { len: usize, min: usize, max: usize }, InvalidMessageLength { len: usize, min: usize, max: usize },
#[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")] #[error("Checksum mismatch: expected={expected:08x}, got={got:08x}")]
ChecksumMismatch { expected: u32, got: u32 }, ChecksumMismatch { expected: u32, got: u32 },
#[error("Sequence number mismatch: expected={expected}, got={got}")] #[error("Sequence number mismatch: expected={expected}, got={got}")]
SeqNoMismatch { expected: i32, got: i32 }, SeqNoMismatch { expected: i32, got: i32 },
#[error("TLS handshake failed: {reason}")] #[error("TLS handshake failed: {reason}")]
TlsHandshakeFailed { reason: String }, TlsHandshakeFailed { reason: String },
#[error("Telegram handshake timeout")] #[error("Telegram handshake timeout")]
TgHandshakeTimeout, TgHandshakeTimeout,
// ============= Network Errors ============= // ============= Network Errors =============
#[error("Connection timeout to {addr}")] #[error("Connection timeout to {addr}")]
ConnectionTimeout { addr: String }, ConnectionTimeout { addr: String },
#[error("Connection refused by {addr}")] #[error("Connection refused by {addr}")]
ConnectionRefused { addr: String }, ConnectionRefused { addr: String },
#[error("IO error: {0}")] #[error("IO error: {0}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
// ============= Proxy Protocol Errors ============= // ============= Proxy Protocol Errors =============
#[error("Invalid proxy protocol header")] #[error("Invalid proxy protocol header")]
InvalidProxyProtocol, InvalidProxyProtocol,
#[error("Proxy error: {0}")] #[error("Proxy error: {0}")]
Proxy(String), Proxy(String),
// ============= Config Errors ============= // ============= Config Errors =============
#[error("Config error: {0}")] #[error("Config error: {0}")]
Config(String), Config(String),
#[error("Invalid secret for user {user}: {reason}")] #[error("Invalid secret for user {user}: {reason}")]
InvalidSecret { user: String, reason: String }, InvalidSecret { user: String, reason: String },
// ============= User Errors ============= // ============= User Errors =============
#[error("User {user} expired")] #[error("User {user} expired")]
UserExpired { user: String }, UserExpired { user: String },
#[error("User {user} exceeded connection limit")] #[error("User {user} exceeded connection limit")]
ConnectionLimitExceeded { user: String }, ConnectionLimitExceeded { user: String },
#[error("User {user} exceeded data quota")] #[error("User {user} exceeded data quota")]
DataQuotaExceeded { user: String }, DataQuotaExceeded { user: String },
#[error("Unknown user")] #[error("Unknown user")]
UnknownUser, UnknownUser,
#[error("Rate limited")] #[error("Rate limited")]
RateLimited, RateLimited,
// ============= General Errors ============= // ============= General Errors =============
#[error("Internal error: {0}")] #[error("Internal error: {0}")]
Internal(String), Internal(String),
} }
@ -257,7 +269,7 @@ impl Recoverable for ProxyError {
_ => false, _ => false,
} }
} }
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
match self { match self {
Self::Stream(e) => e.can_continue(), Self::Stream(e) => e.can_continue(),
@ -289,19 +301,17 @@ impl<T, R, W> HandshakeResult<T, R, W> {
pub fn is_success(&self) -> bool { pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_)) matches!(self, HandshakeResult::Success(_))
} }
/// Check if bad client /// Check if bad client
pub fn is_bad_client(&self) -> bool { pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient { .. }) matches!(self, HandshakeResult::BadClient { .. })
} }
/// Map the success value /// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::BadClient { reader, writer }
}
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
@ -328,104 +338,76 @@ impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_stream_error_display() { fn test_stream_error_display() {
let err = StreamError::PartialRead { let err = StreamError::PartialRead { expected: 100, got: 50 };
expected: 100,
got: 50,
};
assert!(err.to_string().contains("100")); assert!(err.to_string().contains("100"));
assert!(err.to_string().contains("50")); assert!(err.to_string().contains("50"));
let err = StreamError::Poisoned { let err = StreamError::Poisoned { reason: "test".into() };
reason: "test".into(),
};
assert!(err.to_string().contains("test")); assert!(err.to_string().contains("test"));
} }
#[test] #[test]
fn test_stream_error_recoverable() { fn test_stream_error_recoverable() {
assert!( assert!(StreamError::PartialRead { expected: 10, got: 5 }.is_recoverable());
StreamError::PartialRead { assert!(StreamError::PartialWrite { expected: 10, written: 5 }.is_recoverable());
expected: 10,
got: 5
}
.is_recoverable()
);
assert!(
StreamError::PartialWrite {
expected: 10,
written: 5
}
.is_recoverable()
);
assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable()); assert!(!StreamError::Poisoned { reason: "x".into() }.is_recoverable());
assert!(!StreamError::UnexpectedEof.is_recoverable()); assert!(!StreamError::UnexpectedEof.is_recoverable());
} }
#[test] #[test]
fn test_stream_error_can_continue() { fn test_stream_error_can_continue() {
assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue()); assert!(!StreamError::Poisoned { reason: "x".into() }.can_continue());
assert!(!StreamError::UnexpectedEof.can_continue()); assert!(!StreamError::UnexpectedEof.can_continue());
assert!( assert!(StreamError::PartialRead { expected: 10, got: 5 }.can_continue());
StreamError::PartialRead {
expected: 10,
got: 5
}
.can_continue()
);
} }
#[test] #[test]
fn test_stream_error_to_io_error() { fn test_stream_error_to_io_error() {
let stream_err = StreamError::UnexpectedEof; let stream_err = StreamError::UnexpectedEof;
let io_err: std::io::Error = stream_err.into(); let io_err: std::io::Error = stream_err.into();
assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof); assert_eq!(io_err.kind(), std::io::ErrorKind::UnexpectedEof);
} }
#[test] #[test]
fn test_handshake_result() { fn test_handshake_result() {
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
assert!(success.is_success()); assert!(success.is_success());
assert!(!success.is_bad_client()); assert!(!success.is_bad_client());
let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
reader: (),
writer: (),
};
assert!(!bad.is_success()); assert!(!bad.is_success());
assert!(bad.is_bad_client()); assert!(bad.is_bad_client());
} }
#[test] #[test]
fn test_handshake_result_map() { fn test_handshake_result_map() {
let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2); let mapped = success.map(|x| x * 2);
match mapped { match mapped {
HandshakeResult::Success(v) => assert_eq!(v, 84), HandshakeResult::Success(v) => assert_eq!(v, 84),
_ => panic!("Expected success"), _ => panic!("Expected success"),
} }
} }
#[test] #[test]
fn test_proxy_error_recoverable() { fn test_proxy_error_recoverable() {
let err = ProxyError::RateLimited; let err = ProxyError::RateLimited;
assert!(err.is_recoverable()); assert!(err.is_recoverable());
let err = ProxyError::InvalidHandshake("bad".into()); let err = ProxyError::InvalidHandshake("bad".into());
assert!(!err.is_recoverable()); assert!(!err.is_recoverable());
} }
#[test] #[test]
fn test_error_display() { fn test_error_display() {
let err = ProxyError::ConnectionTimeout { let err = ProxyError::ConnectionTimeout { addr: "1.2.3.4:443".into() };
addr: "1.2.3.4:443".into(),
};
assert!(err.to_string().contains("1.2.3.4:443")); assert!(err.to_string().contains("1.2.3.4:443"));
let err = ProxyError::InvalidProxyProtocol; let err = ProxyError::InvalidProxyProtocol;
assert!(err.to_string().contains("proxy protocol")); assert!(err.to_string().contains("proxy protocol"));
} }
} }

View File

@ -5,9 +5,9 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::sync::Mutex;
use tokio::sync::{Mutex as AsyncMutex, RwLock}; use tokio::sync::{Mutex as AsyncMutex, RwLock};
@ -22,7 +22,7 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>, limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>, limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>, last_compact_epoch_secs: Arc<AtomicU64>,
cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>, pub(crate) cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>, cleanup_drain_lock: Arc<AsyncMutex<()>>,
} }
@ -41,6 +41,7 @@ impl UserIpTracker {
} }
} }
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() { match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)), Ok(mut queue) => queue.push((user, ip)),
@ -57,19 +58,6 @@ impl UserIpTracker {
} }
} }
#[cfg(test)]
pub(crate) fn cleanup_queue_len_for_tests(&self) -> usize {
self.cleanup_queue
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.len()
}
#[cfg(test)]
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
Arc::clone(&self.cleanup_queue)
}
pub(crate) async fn drain_cleanup_queue(&self) { pub(crate) async fn drain_cleanup_queue(&self) {
// Serialize queue draining and active-IP mutation so check-and-add cannot // Serialize queue draining and active-IP mutation so check-and-add cannot
// observe stale active entries that are already queued for removal. // observe stale active entries that are already queued for removal.
@ -141,8 +129,7 @@ impl UserIpTracker {
let mut active_ips = self.active_ips.write().await; let mut active_ips = self.active_ips.write().await;
let mut recent_ips = self.recent_ips.write().await; let mut recent_ips = self.recent_ips.write().await;
let mut users = let mut users = Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
users.extend(active_ips.keys().cloned()); users.extend(active_ips.keys().cloned());
for user in recent_ips.keys() { for user in recent_ips.keys() {
if !active_ips.contains_key(user) { if !active_ips.contains_key(user) {
@ -151,14 +138,8 @@ impl UserIpTracker {
} }
for user in users { for user in users {
let active_empty = active_ips let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
.get(&user) let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
.map(|ips| ips.is_empty())
.unwrap_or(true);
let recent_empty = recent_ips
.get(&user)
.map(|ips| ips.is_empty())
.unwrap_or(true);
if active_empty && recent_empty { if active_empty && recent_empty {
active_ips.remove(&user); active_ips.remove(&user);
recent_ips.remove(&user); recent_ips.remove(&user);

View File

@ -1,5 +1,3 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -13,10 +11,10 @@ use crate::startup::{
COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY, COMPONENT_DC_CONNECTIVITY_PING, COMPONENT_ME_CONNECTIVITY_PING, COMPONENT_RUNTIME_READY,
StartupTracker, StartupTracker,
}; };
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{ use crate::transport::middle_proxy::{
MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping, MePingFamily, MePingSample, MePool, format_me_route, format_sample_line, run_me_ping,
}; };
use crate::transport::UpstreamManager;
pub(crate) async fn run_startup_connectivity( pub(crate) async fn run_startup_connectivity(
config: &Arc<ProxyConfig>, config: &Arc<ProxyConfig>,
@ -49,15 +47,11 @@ pub(crate) async fn run_startup_connectivity(
let v4_ok = me_results.iter().any(|r| { let v4_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V4) matches!(r.family, MePingFamily::V4)
&& r.samples && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
}); });
let v6_ok = me_results.iter().any(|r| { let v6_ok = me_results.iter().any(|r| {
matches!(r.family, MePingFamily::V6) matches!(r.family, MePingFamily::V6)
&& r.samples && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some())
.iter()
.any(|s| s.error.is_none() && s.handshake_ms.is_some())
}); });
info!("================= Telegram ME Connectivity ================="); info!("================= Telegram ME Connectivity =================");
@ -137,14 +131,8 @@ pub(crate) async fn run_startup_connectivity(
.await; .await;
for upstream_result in &ping_results { for upstream_result in &ping_results {
let v6_works = upstream_result let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some());
.v6_results let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some());
.iter()
.any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result
.v4_results
.iter()
.any(|r| r.rtt_ms.is_some());
if upstream_result.both_available { if upstream_result.both_available {
if prefer_ipv6 { if prefer_ipv6 {

View File

@ -1,7 +1,5 @@
#![allow(clippy::items_after_test_module)]
use std::path::PathBuf;
use std::time::Duration; use std::time::Duration;
use std::path::PathBuf;
use tokio::sync::watch; use tokio::sync::watch;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
@ -12,10 +10,7 @@ use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
}; };
pub(crate) fn resolve_runtime_config_path( pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf {
config_path_cli: &str,
startup_cwd: &std::path::Path,
) -> PathBuf {
let raw = PathBuf::from(config_path_cli); let raw = PathBuf::from(config_path_cli);
let absolute = if raw.is_absolute() { let absolute = if raw.is_absolute() {
raw raw
@ -55,9 +50,7 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
} }
} }
s if s.starts_with("--data-path=") => { s if s.starts_with("--data-path=") => {
data_path = Some(PathBuf::from( data_path = Some(PathBuf::from(s.trim_start_matches("--data-path=").to_string()));
s.trim_start_matches("--data-path=").to_string(),
));
} }
"--silent" | "-s" => { "--silent" | "-s" => {
silent = true; silent = true;
@ -75,9 +68,7 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
eprintln!("Usage: telemt [config.toml] [OPTIONS]"); eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!(); eprintln!();
eprintln!("Options:"); eprintln!("Options:");
eprintln!( eprintln!(" --data-path <DIR> Set data directory (absolute path; overrides config value)");
" --data-path <DIR> Set data directory (absolute path; overrides config value)"
);
eprintln!(" --silent, -s Suppress info logs"); eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent"); eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help"); eprintln!(" --help, -h Show this help");
@ -155,12 +146,7 @@ mod tests {
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
for user_name in config for user_name in config.general.links.show.resolve_users(&config.access.users) {
.general
.links
.show
.resolve_users(&config.access.users)
{
if let Some(secret) = config.access.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!(target: "telemt::links", "User: {}", user_name); info!(target: "telemt::links", "User: {}", user_name);
if config.general.modes.classic { if config.general.modes.classic {
@ -301,10 +287,7 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
return Some(cfg); return Some(cfg);
} }
warn!( warn!(snapshot = label, url, "Startup proxy-config is empty; trying disk cache");
snapshot = label,
url, "Startup proxy-config is empty; trying disk cache"
);
if let Some(path) = cache_path { if let Some(path) = cache_path {
match load_proxy_config_cache(path).await { match load_proxy_config_cache(path).await {
Ok(cached) if !cached.map.is_empty() => { Ok(cached) if !cached.map.is_empty() => {
@ -319,7 +302,8 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
Ok(_) => { Ok(_) => {
warn!( warn!(
snapshot = label, snapshot = label,
path, "Startup proxy-config cache is empty; ignoring cache file" path,
"Startup proxy-config cache is empty; ignoring cache file"
); );
} }
Err(cache_err) => { Err(cache_err) => {
@ -363,7 +347,8 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
Ok(_) => { Ok(_) => {
warn!( warn!(
snapshot = label, snapshot = label,
path, "Startup proxy-config cache is empty; ignoring cache file" path,
"Startup proxy-config cache is empty; ignoring cache file"
); );
} }
Err(cache_err) => { Err(cache_err) => {

View File

@ -12,15 +12,17 @@ use tracing::{debug, error, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::proxy::ClientHandler;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController}; use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
use crate::proxy::ClientHandler;
use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker}; use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::tls_front::TlsFrontCache; use crate::tls_front::TlsFrontCache;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes}; use crate::transport::{
ListenOptions, UpstreamManager, create_listener, find_listener_processes,
};
use super::helpers::{is_expected_handshake_eof, print_proxy_links}; use super::helpers::{is_expected_handshake_eof, print_proxy_links};
@ -79,9 +81,8 @@ pub(crate) async fn bind_listeners(
Ok(socket) => { Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?; let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr); info!("Listening on {}", addr);
let listener_proxy_protocol = listener_conf let listener_proxy_protocol =
.proxy_protocol listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol);
.unwrap_or(config.server.proxy_protocol);
let public_host = if let Some(ref announce) = listener_conf.announce { let public_host = if let Some(ref announce) = listener_conf.announce {
announce.clone() announce.clone()
@ -99,14 +100,8 @@ pub(crate) async fn bind_listeners(
listener_conf.ip.to_string() listener_conf.ip.to_string()
}; };
if config.general.links.public_host.is_none() if config.general.links.public_host.is_none() && !config.general.links.show.is_empty() {
&& !config.general.links.show.is_empty() let link_port = config.general.links.public_port.unwrap_or(config.server.port);
{
let link_port = config
.general
.links
.public_port
.unwrap_or(config.server.port);
print_proxy_links(&public_host, link_port, config); print_proxy_links(&public_host, link_port, config);
} }
@ -150,14 +145,12 @@ pub(crate) async fn bind_listeners(
let (host, port) = if let Some(ref h) = config.general.links.public_host { let (host, port) = if let Some(ref h) = config.general.links.public_host {
( (
h.clone(), h.clone(),
config config.general.links.public_port.unwrap_or(config.server.port),
.general
.links
.public_port
.unwrap_or(config.server.port),
) )
} else { } else {
let ip = detected_ip_v4.or(detected_ip_v6).map(|ip| ip.to_string()); let ip = detected_ip_v4
.or(detected_ip_v6)
.map(|ip| ip.to_string());
if ip.is_none() { if ip.is_none() {
warn!( warn!(
"show_link is configured but public IP could not be detected. Set public_host in config." "show_link is configured but public IP could not be detected. Set public_host in config."
@ -165,11 +158,7 @@ pub(crate) async fn bind_listeners(
} }
( (
ip.unwrap_or_else(|| "UNKNOWN".to_string()), ip.unwrap_or_else(|| "UNKNOWN".to_string()),
config config.general.links.public_port.unwrap_or(config.server.port),
.general
.links
.public_port
.unwrap_or(config.server.port),
) )
}; };
@ -189,19 +178,13 @@ pub(crate) async fn bind_listeners(
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(mode); let perms = std::fs::Permissions::from_mode(mode);
if let Err(e) = std::fs::set_permissions(unix_path, perms) { if let Err(e) = std::fs::set_permissions(unix_path, perms) {
error!( error!("Failed to set unix socket permissions to {}: {}", perm_str, e);
"Failed to set unix socket permissions to {}: {}",
perm_str, e
);
} else { } else {
info!("Listening on unix:{} (mode {})", unix_path, perm_str); info!("Listening on unix:{} (mode {})", unix_path, perm_str);
} }
} }
Err(e) => { Err(e) => {
warn!( warn!("Invalid listen_unix_sock_perm '{}': {}. Ignoring.", perm_str, e);
"Invalid listen_unix_sock_perm '{}': {}. Ignoring.",
perm_str, e
);
info!("Listening on unix:{}", unix_path); info!("Listening on unix:{}", unix_path);
} }
} }
@ -235,8 +218,10 @@ pub(crate) async fn bind_listeners(
drop(stream); drop(stream);
continue; continue;
} }
let accept_permit_timeout_ms = let accept_permit_timeout_ms = config_rx_unix
config_rx_unix.borrow().server.accept_permit_timeout_ms; .borrow()
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 { let permit = if accept_permit_timeout_ms == 0 {
match max_connections_unix.clone().acquire_owned().await { match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit, Ok(permit) => permit,
@ -376,8 +361,10 @@ pub(crate) fn spawn_tcp_accept_loops(
drop(stream); drop(stream);
continue; continue;
} }
let accept_permit_timeout_ms = let accept_permit_timeout_ms = config_rx
config_rx.borrow().server.accept_permit_timeout_ms; .borrow()
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 { let permit = if accept_permit_timeout_ms == 0 {
match max_connections_tcp.clone().acquire_owned().await { match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit, Ok(permit) => permit,

View File

@ -1,5 +1,3 @@
#![allow(clippy::too_many_arguments)]
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -14,8 +12,8 @@ use crate::startup::{
COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH, StartupMeStatus, StartupTracker,
}; };
use crate::stats::Stats; use crate::stats::Stats;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use super::helpers::load_startup_proxy_config_snapshot; use super::helpers::load_startup_proxy_config_snapshot;
@ -231,12 +229,8 @@ pub(crate) async fn initialize_me_pool(
config.general.me_adaptive_floor_recover_grace_secs, config.general.me_adaptive_floor_recover_grace_secs,
config.general.me_adaptive_floor_writers_per_core_total, config.general.me_adaptive_floor_writers_per_core_total,
config.general.me_adaptive_floor_cpu_cores_override, config.general.me_adaptive_floor_cpu_cores_override,
config config.general.me_adaptive_floor_max_extra_writers_single_per_core,
.general config.general.me_adaptive_floor_max_extra_writers_multi_per_core,
.me_adaptive_floor_max_extra_writers_single_per_core,
config
.general
.me_adaptive_floor_max_extra_writers_multi_per_core,
config.general.me_adaptive_floor_max_active_writers_per_core, config.general.me_adaptive_floor_max_active_writers_per_core,
config.general.me_adaptive_floor_max_warm_writers_per_core, config.general.me_adaptive_floor_max_warm_writers_per_core,
config.general.me_adaptive_floor_max_active_writers_global, config.general.me_adaptive_floor_max_active_writers_global,
@ -463,70 +457,64 @@ pub(crate) async fn initialize_me_pool(
"Middle-End pool initialized successfully" "Middle-End pool initialized successfully"
); );
// ── Supervised background tasks ────────────────── // ── Supervised background tasks ──────────────────
let pool_clone = pool.clone(); let pool_clone = pool.clone();
let rng_clone = rng.clone(); let rng_clone = rng.clone();
let min_conns = pool_size; let min_conns = pool_size;
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let p = pool_clone.clone(); let p = pool_clone.clone();
let r = rng_clone.clone(); let r = rng_clone.clone();
let res = tokio::spawn(async move { let res = tokio::spawn(async move {
crate::transport::middle_proxy::me_health_monitor( crate::transport::middle_proxy::me_health_monitor(
p, r, min_conns, p, r, min_conns,
) )
.await;
})
.await; .await;
}) match res {
.await; Ok(()) => warn!("me_health_monitor exited unexpectedly, restarting"),
match res { Err(e) => {
Ok(()) => warn!( error!(error = %e, "me_health_monitor panicked, restarting in 1s");
"me_health_monitor exited unexpectedly, restarting" tokio::time::sleep(Duration::from_secs(1)).await;
), }
Err(e) => {
error!(error = %e, "me_health_monitor panicked, restarting in 1s");
tokio::time::sleep(Duration::from_secs(1)).await;
} }
} }
} });
}); let pool_drain_enforcer = pool.clone();
let pool_drain_enforcer = pool.clone(); tokio::spawn(async move {
tokio::spawn(async move { loop {
loop { let p = pool_drain_enforcer.clone();
let p = pool_drain_enforcer.clone(); let res = tokio::spawn(async move {
let res = tokio::spawn(async move {
crate::transport::middle_proxy::me_drain_timeout_enforcer(p).await; crate::transport::middle_proxy::me_drain_timeout_enforcer(p).await;
}) })
.await; .await;
match res { match res {
Ok(()) => warn!( Ok(()) => warn!("me_drain_timeout_enforcer exited unexpectedly, restarting"),
"me_drain_timeout_enforcer exited unexpectedly, restarting" Err(e) => {
), error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s");
Err(e) => { tokio::time::sleep(Duration::from_secs(1)).await;
error!(error = %e, "me_drain_timeout_enforcer panicked, restarting in 1s"); }
tokio::time::sleep(Duration::from_secs(1)).await;
} }
} }
} });
}); let pool_watchdog = pool.clone();
let pool_watchdog = pool.clone(); tokio::spawn(async move {
tokio::spawn(async move { loop {
loop { let p = pool_watchdog.clone();
let p = pool_watchdog.clone(); let res = tokio::spawn(async move {
let res = tokio::spawn(async move {
crate::transport::middle_proxy::me_zombie_writer_watchdog(p).await; crate::transport::middle_proxy::me_zombie_writer_watchdog(p).await;
}) })
.await; .await;
match res { match res {
Ok(()) => warn!( Ok(()) => warn!("me_zombie_writer_watchdog exited unexpectedly, restarting"),
"me_zombie_writer_watchdog exited unexpectedly, restarting" Err(e) => {
), error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s");
Err(e) => { tokio::time::sleep(Duration::from_secs(1)).await;
error!(error = %e, "me_zombie_writer_watchdog panicked, restarting in 1s"); }
tokio::time::sleep(Duration::from_secs(1)).await;
} }
} }
} });
});
break Some(pool); break Some(pool);
} }

View File

@ -11,9 +11,9 @@
// - admission: conditional-cast gate and route mode switching. // - admission: conditional-cast gate and route mode switching.
// - listeners: TCP/Unix listener bind and accept-loop orchestration. // - listeners: TCP/Unix listener bind and accept-loop orchestration.
// - shutdown: graceful shutdown sequence and uptime logging. // - shutdown: graceful shutdown sequence and uptime logging.
mod helpers;
mod admission; mod admission;
mod connectivity; mod connectivity;
mod helpers;
mod listeners; mod listeners;
mod me_startup; mod me_startup;
mod runtime_tasks; mod runtime_tasks;
@ -33,18 +33,18 @@ use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT,
COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6,
COMPONENT_ME_SECRET_FETCH, COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::startup::{
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD,
COMPONENT_ME_POOL_CONSTRUCT, COMPONENT_ME_POOL_INIT_STAGE1,
COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6, COMPONENT_ME_SECRET_FETCH,
COMPONENT_NETWORK_PROBE, COMPONENT_TRACING_INIT, StartupMeStatus, StartupTracker,
};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::UpstreamManager;
use helpers::{parse_cli, resolve_runtime_config_path}; use helpers::{parse_cli, resolve_runtime_config_path};
/// Runs the full telemt runtime startup pipeline and blocks until shutdown. /// Runs the full telemt runtime startup pipeline and blocks until shutdown.
@ -56,10 +56,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.as_secs(); .as_secs();
let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs)); let startup_tracker = Arc::new(StartupTracker::new(process_started_at_epoch_secs));
startup_tracker startup_tracker
.start_component( .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
COMPONENT_CONFIG_LOAD,
Some("load and validate config".to_string()),
)
.await; .await;
let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli();
let startup_cwd = match std::env::current_dir() { let startup_cwd = match std::env::current_dir() {
@ -80,10 +77,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
} else { } else {
let default = ProxyConfig::default(); let default = ProxyConfig::default();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!( eprintln!("[telemt] Created default config at {}", config_path.display());
"[telemt] Created default config at {}",
config_path.display()
);
default default
} }
} }
@ -100,38 +94,24 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
if let Some(ref data_path) = config.general.data_path { if let Some(ref data_path) = config.general.data_path {
if !data_path.is_absolute() { if !data_path.is_absolute() {
eprintln!( eprintln!("[telemt] data_path must be absolute: {}", data_path.display());
"[telemt] data_path must be absolute: {}",
data_path.display()
);
std::process::exit(1); std::process::exit(1);
} }
if data_path.exists() { if data_path.exists() {
if !data_path.is_dir() { if !data_path.is_dir() {
eprintln!( eprintln!("[telemt] data_path exists but is not a directory: {}", data_path.display());
"[telemt] data_path exists but is not a directory: {}",
data_path.display()
);
std::process::exit(1); std::process::exit(1);
} }
} else { } else {
if let Err(e) = std::fs::create_dir_all(data_path) { if let Err(e) = std::fs::create_dir_all(data_path) {
eprintln!( eprintln!("[telemt] Can't create data_path {}: {}", data_path.display(), e);
"[telemt] Can't create data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1); std::process::exit(1);
} }
} }
if let Err(e) = std::env::set_current_dir(data_path) { if let Err(e) = std::env::set_current_dir(data_path) {
eprintln!( eprintln!("[telemt] Can't use data_path {}: {}", data_path.display(), e);
"[telemt] Can't use data_path {}: {}",
data_path.display(),
e
);
std::process::exit(1); std::process::exit(1);
} }
} }
@ -155,10 +135,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
startup_tracker startup_tracker
.start_component( .start_component(COMPONENT_TRACING_INIT, Some("initialize tracing subscriber".to_string()))
COMPONENT_TRACING_INIT,
Some("initialize tracing subscriber".to_string()),
)
.await; .await;
// Configure color output based on config // Configure color output based on config
@ -173,10 +150,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.with(fmt_layer) .with(fmt_layer)
.init(); .init();
startup_tracker startup_tracker
.complete_component( .complete_component(COMPONENT_TRACING_INIT, Some("tracing initialized".to_string()))
COMPONENT_TRACING_INIT,
Some("tracing initialized".to_string()),
)
.await; .await;
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
@ -242,8 +216,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
config.access.user_max_unique_ips_window_secs, config.access.user_max_unique_ips_window_secs,
) )
.await; .await;
if config.access.user_max_unique_ips_global_each > 0 if config.access.user_max_unique_ips_global_each > 0 || !config.access.user_max_unique_ips.is_empty()
|| !config.access.user_max_unique_ips.is_empty()
{ {
info!( info!(
global_each_limit = config.access.user_max_unique_ips_global_each, global_each_limit = config.access.user_max_unique_ips_global_each,
@ -270,10 +243,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode)); let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode));
let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>)); let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>));
startup_tracker startup_tracker
.start_component( .start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string()))
COMPONENT_API_BOOTSTRAP,
Some("spawn API listener task".to_string()),
)
.await; .await;
if config.server.api.enabled { if config.server.api.enabled {
@ -356,10 +326,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
.await; .await;
startup_tracker startup_tracker
.start_component( .start_component(COMPONENT_NETWORK_PROBE, Some("probe network capabilities".to_string()))
COMPONENT_NETWORK_PROBE,
Some("probe network capabilities".to_string()),
)
.await; .await;
let probe = run_probe( let probe = run_probe(
&config.network, &config.network,
@ -372,8 +339,11 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
probe.detected_ipv4.map(IpAddr::V4), probe.detected_ipv4.map(IpAddr::V4),
probe.detected_ipv6.map(IpAddr::V6), probe.detected_ipv6.map(IpAddr::V6),
)); ));
let decision = let decision = decide_network_capabilities(
decide_network_capabilities(&config.network, &probe, config.general.middle_proxy_nat_ip); &config.network,
&probe,
config.general.middle_proxy_nat_ip,
);
log_probe_result(&probe, &decision); log_probe_result(&probe, &decision);
startup_tracker startup_tracker
.complete_component( .complete_component(
@ -476,16 +446,24 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
// If ME failed to initialize, force direct-only mode. // If ME failed to initialize, force direct-only mode.
if me_pool.is_some() { if me_pool.is_some() {
startup_tracker.set_transport_mode("middle_proxy").await; startup_tracker
startup_tracker.set_degraded(false).await; .set_transport_mode("middle_proxy")
.await;
startup_tracker
.set_degraded(false)
.await;
info!("Transport: Middle-End Proxy - all DC-over-RPC"); info!("Transport: Middle-End Proxy - all DC-over-RPC");
} else { } else {
let _ = use_middle_proxy; let _ = use_middle_proxy;
use_middle_proxy = false; use_middle_proxy = false;
// Make runtime config reflect direct-only mode for handlers. // Make runtime config reflect direct-only mode for handlers.
config.general.use_middle_proxy = false; config.general.use_middle_proxy = false;
startup_tracker.set_transport_mode("direct").await; startup_tracker
startup_tracker.set_degraded(true).await; .set_transport_mode("direct")
.await;
startup_tracker
.set_degraded(true)
.await;
if me2dc_fallback { if me2dc_fallback {
startup_tracker startup_tracker
.set_me_status(StartupMeStatus::Failed, "fallback_to_direct") .set_me_status(StartupMeStatus::Failed, "fallback_to_direct")

View File

@ -4,24 +4,21 @@ use std::sync::Arc;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use tracing::{debug, warn}; use tracing::{debug, warn};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::reload; use tracing_subscriber::reload;
use tracing_subscriber::EnvFilter;
use crate::config::hot_reload::spawn_config_watcher;
use crate::config::{LogLevel, ProxyConfig}; use crate::config::{LogLevel, ProxyConfig};
use crate::config::hot_reload::spawn_config_watcher;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::metrics; use crate::metrics;
use crate::network::probe::NetworkProbe; use crate::network::probe::NetworkProbe;
use crate::startup::{ use crate::startup::{COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY, StartupTracker};
COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY,
StartupTracker,
};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{MePool, MeReinitTrigger}; use crate::transport::middle_proxy::{MePool, MeReinitTrigger};
use crate::transport::UpstreamManager;
use super::helpers::write_beobachten_snapshot; use super::helpers::write_beobachten_snapshot;
@ -82,13 +79,15 @@ pub(crate) async fn spawn_runtime_tasks(
Some("spawn config hot-reload watcher".to_string()), Some("spawn config hot-reload watcher".to_string()),
) )
.await; .await;
let (config_rx, log_level_rx): (watch::Receiver<Arc<ProxyConfig>>, watch::Receiver<LogLevel>) = let (config_rx, log_level_rx): (
spawn_config_watcher( watch::Receiver<Arc<ProxyConfig>>,
config_path.to_path_buf(), watch::Receiver<LogLevel>,
config.clone(), ) = spawn_config_watcher(
detected_ip_v4, config_path.to_path_buf(),
detected_ip_v6, config.clone(),
); detected_ip_v4,
detected_ip_v6,
);
startup_tracker startup_tracker
.complete_component( .complete_component(
COMPONENT_CONFIG_WATCHER_START, COMPONENT_CONFIG_WATCHER_START,
@ -115,8 +114,7 @@ pub(crate) async fn spawn_runtime_tasks(
break; break;
} }
let cfg = config_rx_policy.borrow_and_update().clone(); let cfg = config_rx_policy.borrow_and_update().clone();
stats_policy stats_policy.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry));
.apply_telemetry_policy(TelemetryPolicy::from_config(&cfg.general.telemetry));
if let Some(pool) = &me_pool_for_policy { if let Some(pool) = &me_pool_for_policy {
pool.update_runtime_transport_policy( pool.update_runtime_transport_policy(
cfg.general.me_socks_kdf_policy, cfg.general.me_socks_kdf_policy,
@ -132,11 +130,7 @@ pub(crate) async fn spawn_runtime_tasks(
let ip_tracker_policy = ip_tracker.clone(); let ip_tracker_policy = ip_tracker.clone();
let mut config_rx_ip_limits = config_rx.clone(); let mut config_rx_ip_limits = config_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut prev_limits = config_rx_ip_limits let mut prev_limits = config_rx_ip_limits.borrow().access.user_max_unique_ips.clone();
.borrow()
.access
.user_max_unique_ips
.clone();
let mut prev_global_each = config_rx_ip_limits let mut prev_global_each = config_rx_ip_limits
.borrow() .borrow()
.access .access
@ -189,9 +183,7 @@ pub(crate) async fn spawn_runtime_tasks(
let sleep_secs = cfg.general.beobachten_flush_secs.max(1); let sleep_secs = cfg.general.beobachten_flush_secs.max(1);
if cfg.general.beobachten { if cfg.general.beobachten {
let ttl = std::time::Duration::from_secs( let ttl = std::time::Duration::from_secs(cfg.general.beobachten_minutes.saturating_mul(60));
cfg.general.beobachten_minutes.saturating_mul(60),
);
let path = cfg.general.beobachten_file.clone(); let path = cfg.general.beobachten_file.clone();
let snapshot = beobachten_writer.snapshot_text(ttl); let snapshot = beobachten_writer.snapshot_text(ttl);
if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await { if let Err(e) = write_beobachten_snapshot(&path, &snapshot).await {
@ -235,11 +227,8 @@ pub(crate) async fn spawn_runtime_tasks(
let config_rx_clone_rot = config_rx.clone(); let config_rx_clone_rot = config_rx.clone();
let reinit_tx_rotation = reinit_tx.clone(); let reinit_tx_rotation = reinit_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
crate::transport::middle_proxy::me_rotation_task( crate::transport::middle_proxy::me_rotation_task(config_rx_clone_rot, reinit_tx_rotation)
config_rx_clone_rot, .await;
reinit_tx_rotation,
)
.await;
}); });
} }

View File

@ -16,11 +16,8 @@ pub(crate) async fn wait_for_shutdown(process_started_at: Instant, me_pool: Opti
let uptime_secs = process_started_at.elapsed().as_secs(); let uptime_secs = process_started_at.elapsed().as_secs();
info!("Uptime: {}", format_uptime(uptime_secs)); info!("Uptime: {}", format_uptime(uptime_secs));
if let Some(pool) = &me_pool { if let Some(pool) = &me_pool {
match tokio::time::timeout( match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all())
Duration::from_secs(2), .await
pool.shutdown_send_close_conn_all(),
)
.await
{ {
Ok(total) => { Ok(total) => {
info!( info!(

View File

@ -7,14 +7,11 @@ mod crypto;
mod error; mod error;
mod ip_tracker; mod ip_tracker;
#[cfg(test)] #[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"] #[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests; mod ip_tracker_regression_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
mod maestro; mod maestro;
mod metrics; mod metrics;
mod network; mod network;

View File

@ -1,5 +1,5 @@
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible; use std::convert::Infallible;
use std::collections::{BTreeSet, HashMap};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -11,12 +11,12 @@ use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode}; use hyper::{Request, Response, StatusCode};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::{debug, info, warn}; use tracing::{info, warn, debug};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::Stats;
use crate::transport::{ListenOptions, create_listener}; use crate::transport::{ListenOptions, create_listener};
pub async fn serve( pub async fn serve(
@ -62,10 +62,7 @@ pub async fn serve(
let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port)); let addr_v4 = SocketAddr::from(([0, 0, 0, 0], port));
match bind_metrics_listener(addr_v4, false) { match bind_metrics_listener(addr_v4, false) {
Ok(listener) => { Ok(listener) => {
info!( info!("Metrics endpoint: http://{}/metrics and /beobachten", addr_v4);
"Metrics endpoint: http://{}/metrics and /beobachten",
addr_v4
);
listener_v4 = Some(listener); listener_v4 = Some(listener);
} }
Err(e) => { Err(e) => {
@ -76,10 +73,7 @@ pub async fn serve(
let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port)); let addr_v6 = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port));
match bind_metrics_listener(addr_v6, true) { match bind_metrics_listener(addr_v6, true) {
Ok(listener) => { Ok(listener) => {
info!( info!("Metrics endpoint: http://[::]:{}/metrics and /beobachten", port);
"Metrics endpoint: http://[::]:{}/metrics and /beobachten",
port
);
listener_v6 = Some(listener); listener_v6 = Some(listener);
} }
Err(e) => { Err(e) => {
@ -115,7 +109,12 @@ pub async fn serve(
.await; .await;
}); });
serve_listener( serve_listener(
listener4, stats, beobachten, ip_tracker, config_rx, whitelist, listener4,
stats,
beobachten,
ip_tracker,
config_rx,
whitelist,
) )
.await; .await;
} }
@ -232,10 +231,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge"); let _ = writeln!(out, "# TYPE telemt_uptime_seconds gauge");
let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs()); let _ = writeln!(out, "telemt_uptime_seconds {:.1}", stats.uptime_secs());
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_telemetry_core_enabled Runtime core telemetry switch");
out,
"# HELP telemt_telemetry_core_enabled Runtime core telemetry switch"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_core_enabled gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -243,10 +239,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
if core_enabled { 1 } else { 0 } if core_enabled { 1 } else { 0 }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch");
out,
"# HELP telemt_telemetry_user_enabled Runtime per-user telemetry switch"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_user_enabled gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -254,10 +247,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
if user_enabled { 1 } else { 0 } if user_enabled { 1 } else { 0 }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_telemetry_me_level Runtime ME telemetry level flag");
out,
"# HELP telemt_telemetry_me_level Runtime ME telemetry level flag"
);
let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge"); let _ = writeln!(out, "# TYPE telemt_telemetry_me_level gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -287,40 +277,23 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_connections_total Total accepted connections");
out,
"# HELP telemt_connections_total Total accepted connections"
);
let _ = writeln!(out, "# TYPE telemt_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_connections_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_connections_total {}", "telemt_connections_total {}",
if core_enabled { if core_enabled { stats.get_connects_all() } else { 0 }
stats.get_connects_all()
} else {
0
}
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_connections_bad_total Bad/rejected connections");
out,
"# HELP telemt_connections_bad_total Bad/rejected connections"
);
let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter"); let _ = writeln!(out, "# TYPE telemt_connections_bad_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_connections_bad_total {}", "telemt_connections_bad_total {}",
if core_enabled { if core_enabled { stats.get_connects_bad() } else { 0 }
stats.get_connects_bad()
} else {
0
}
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts");
out,
"# HELP telemt_handshake_timeouts_total Handshake timeouts"
);
let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -399,10 +372,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle" "# HELP telemt_upstream_connect_attempts_per_request Histogram-like buckets for attempts per upstream connect request cycle"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_upstream_connect_attempts_per_request counter");
out,
"# TYPE telemt_upstream_connect_attempts_per_request counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}", "telemt_upstream_connect_attempts_per_request{{bucket=\"1\"}} {}",
@ -444,10 +414,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration" "# HELP telemt_upstream_connect_duration_success_total Histogram-like buckets of successful upstream connect cycle duration"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_success_total counter");
out,
"# TYPE telemt_upstream_connect_duration_success_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}", "telemt_upstream_connect_duration_success_total{{bucket=\"le_100ms\"}} {}",
@ -489,10 +456,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration" "# HELP telemt_upstream_connect_duration_fail_total Histogram-like buckets of failed upstream connect cycle duration"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_upstream_connect_duration_fail_total counter");
out,
"# TYPE telemt_upstream_connect_duration_fail_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}", "telemt_upstream_connect_duration_fail_total{{bucket=\"le_100ms\"}} {}",
@ -530,10 +494,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_keepalive_sent_total ME keepalive frames sent");
out,
"# HELP telemt_me_keepalive_sent_total ME keepalive frames sent"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_sent_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -545,10 +506,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_keepalive_failed_total ME keepalive send failures");
out,
"# HELP telemt_me_keepalive_failed_total ME keepalive send failures"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -560,10 +518,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies");
out,
"# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -575,10 +530,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts");
out,
"# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"
);
let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -594,10 +546,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent" "# HELP telemt_me_rpc_proxy_req_signal_sent_total Service RPC_PROXY_REQ activity signals sent"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter");
out,
"# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_rpc_proxy_req_signal_sent_total {}", "telemt_me_rpc_proxy_req_signal_sent_total {}",
@ -680,10 +629,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts");
out,
"# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"
);
let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -695,10 +641,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_reconnect_success_total ME reconnect successes");
out,
"# HELP telemt_me_reconnect_success_total ME reconnect successes"
);
let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -710,10 +653,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream");
out,
"# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"
);
let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter"); let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -725,25 +665,20 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code");
out,
"# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"
);
let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter"); let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter");
if me_allows_normal { if me_allows_normal {
for (error_code, count) in stats.get_me_handshake_error_code_counts() { for (error_code, count) in stats.get_me_handshake_error_code_counts() {
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}", "telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}",
error_code, count error_code,
count
); );
} }
} }
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_reader_eof_total ME reader EOF terminations");
out,
"# HELP telemt_me_reader_eof_total ME reader EOF terminations"
);
let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -845,10 +780,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches");
out,
"# HELP telemt_me_seq_mismatch_total ME sequence mismatches"
);
let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -860,10 +792,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn");
out,
"# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"
);
let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -875,14 +804,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed");
out, let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter");
"# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"
);
let _ = writeln!(
out,
"# TYPE telemt_me_route_drop_channel_closed_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_route_drop_channel_closed_total {}", "telemt_me_route_drop_channel_closed_total {}",
@ -893,10 +816,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full");
out,
"# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"
);
let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1053,10 +973,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates" "# HELP telemt_me_writer_pick_mode_switch_total Writer-pick mode switches via runtime updates"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_writer_pick_mode_switch_total counter");
out,
"# TYPE telemt_me_writer_pick_mode_switch_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_pick_mode_switch_total {}", "telemt_me_writer_pick_mode_switch_total {}",
@ -1106,10 +1023,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_kdf_drift_total ME KDF input drift detections");
out,
"# HELP telemt_me_kdf_drift_total ME KDF input drift detections"
);
let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter"); let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1155,10 +1069,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration" "# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter");
out,
"# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_hardswap_pending_ttl_expired_total {}", "telemt_me_hardswap_pending_ttl_expired_total {}",
@ -1390,7 +1301,10 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap" "# HELP telemt_me_adaptive_floor_global_cap_raw Runtime raw global adaptive floor cap"
); );
let _ = writeln!(out, "# TYPE telemt_me_adaptive_floor_global_cap_raw gauge"); let _ = writeln!(
out,
"# TYPE telemt_me_adaptive_floor_global_cap_raw gauge"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_adaptive_floor_global_cap_raw {}", "telemt_me_adaptive_floor_global_cap_raw {}",
@ -1573,10 +1487,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths");
out,
"# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"
);
let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1588,10 +1499,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_desync_total Total crypto-desync detections");
out,
"# HELP telemt_desync_total Total crypto-desync detections"
);
let _ = writeln!(out, "# TYPE telemt_desync_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1603,10 +1511,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_desync_full_logged_total Full forensic desync logs emitted");
out,
"# HELP telemt_desync_full_logged_total Full forensic desync logs emitted"
);
let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_full_logged_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1618,10 +1523,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_desync_suppressed_total Suppressed desync forensic events");
out,
"# HELP telemt_desync_suppressed_total Suppressed desync forensic events"
);
let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_suppressed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1633,10 +1535,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket");
out,
"# HELP telemt_desync_frames_bucket_total Desync count by frames_ok bucket"
);
let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter"); let _ = writeln!(out, "# TYPE telemt_desync_frames_bucket_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1675,10 +1574,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_pool_swap_total Successful ME pool swaps");
out,
"# HELP telemt_pool_swap_total Successful ME pool swaps"
);
let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_swap_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1690,10 +1586,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_pool_drain_active Active draining ME writers");
out,
"# HELP telemt_pool_drain_active Active draining ME writers"
);
let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge"); let _ = writeln!(out, "# TYPE telemt_pool_drain_active gauge");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1705,10 +1598,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_pool_force_close_total Forced close events for draining writers");
out,
"# HELP telemt_pool_force_close_total Forced close events for draining writers"
);
let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_force_close_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1720,10 +1610,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds");
out,
"# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"
);
let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1735,10 +1622,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals");
out,
"# HELP telemt_me_writer_removed_total Total ME writer removals"
);
let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1754,10 +1638,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill" "# HELP telemt_me_writer_removed_unexpected_total Unexpected ME writer removals that triggered refill"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_writer_removed_unexpected_total counter");
out,
"# TYPE telemt_me_writer_removed_unexpected_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_removed_unexpected_total {}", "telemt_me_writer_removed_unexpected_total {}",
@ -1768,10 +1649,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started");
out,
"# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"
);
let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter"); let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1787,10 +1665,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup" "# HELP telemt_me_refill_skipped_inflight_total Immediate ME refill skips due to inflight dedup"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_refill_skipped_inflight_total counter");
out,
"# TYPE telemt_me_refill_skipped_inflight_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_refill_skipped_inflight_total {}", "telemt_me_refill_skipped_inflight_total {}",
@ -1801,10 +1676,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
} }
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_me_refill_failed_total Immediate ME refill failures");
out,
"# HELP telemt_me_refill_failed_total Immediate ME refill failures"
);
let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_refill_failed_total counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1820,10 +1692,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint" "# HELP telemt_me_writer_restored_same_endpoint_total Refilled ME writer restored on the same endpoint"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_writer_restored_same_endpoint_total counter");
out,
"# TYPE telemt_me_writer_restored_same_endpoint_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_restored_same_endpoint_total {}", "telemt_me_writer_restored_same_endpoint_total {}",
@ -1838,10 +1707,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out, out,
"# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint" "# HELP telemt_me_writer_restored_fallback_total Refilled ME writer restored via fallback endpoint"
); );
let _ = writeln!( let _ = writeln!(out, "# TYPE telemt_me_writer_restored_fallback_total counter");
out,
"# TYPE telemt_me_writer_restored_fallback_total counter"
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_me_writer_restored_fallback_total {}", "telemt_me_writer_restored_fallback_total {}",
@ -1919,35 +1785,17 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
unresolved_writer_losses unresolved_writer_losses
); );
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections");
out,
"# HELP telemt_user_connections_total Per-user total connections"
);
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections");
out,
"# HELP telemt_user_connections_current Per-user active connections"
);
let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge"); let _ = writeln!(out, "# TYPE telemt_user_connections_current gauge");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_octets_from_client Per-user bytes received");
out,
"# HELP telemt_user_octets_from_client Per-user bytes received"
);
let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter"); let _ = writeln!(out, "# TYPE telemt_user_octets_from_client counter");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_octets_to_client Per-user bytes sent");
out,
"# HELP telemt_user_octets_to_client Per-user bytes sent"
);
let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter"); let _ = writeln!(out, "# TYPE telemt_user_octets_to_client counter");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_msgs_from_client Per-user messages received");
out,
"# HELP telemt_user_msgs_from_client Per-user messages received"
);
let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter"); let _ = writeln!(out, "# TYPE telemt_user_msgs_from_client counter");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_msgs_to_client Per-user messages sent");
out,
"# HELP telemt_user_msgs_to_client Per-user messages sent"
);
let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter"); let _ = writeln!(out, "# TYPE telemt_user_msgs_to_client counter");
let _ = writeln!( let _ = writeln!(
out, out,
@ -1987,45 +1835,12 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
for entry in stats.iter_user_stats() { for entry in stats.iter_user_stats() {
let user = entry.key(); let user = entry.key();
let s = entry.value(); let s = entry.value();
let _ = writeln!( let _ = writeln!(out, "telemt_user_connections_total{{user=\"{}\"}} {}", user, s.connects.load(std::sync::atomic::Ordering::Relaxed));
out, let _ = writeln!(out, "telemt_user_connections_current{{user=\"{}\"}} {}", user, s.curr_connects.load(std::sync::atomic::Ordering::Relaxed));
"telemt_user_connections_total{{user=\"{}\"}} {}", let _ = writeln!(out, "telemt_user_octets_from_client{{user=\"{}\"}} {}", user, s.octets_from_client.load(std::sync::atomic::Ordering::Relaxed));
user, let _ = writeln!(out, "telemt_user_octets_to_client{{user=\"{}\"}} {}", user, s.octets_to_client.load(std::sync::atomic::Ordering::Relaxed));
s.connects.load(std::sync::atomic::Ordering::Relaxed) let _ = writeln!(out, "telemt_user_msgs_from_client{{user=\"{}\"}} {}", user, s.msgs_from_client.load(std::sync::atomic::Ordering::Relaxed));
); let _ = writeln!(out, "telemt_user_msgs_to_client{{user=\"{}\"}} {}", user, s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed));
let _ = writeln!(
out,
"telemt_user_connections_current{{user=\"{}\"}} {}",
user,
s.curr_connects.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_octets_from_client{{user=\"{}\"}} {}",
user,
s.octets_from_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_octets_to_client{{user=\"{}\"}} {}",
user,
s.octets_to_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_msgs_from_client{{user=\"{}\"}} {}",
user,
s.msgs_from_client
.load(std::sync::atomic::Ordering::Relaxed)
);
let _ = writeln!(
out,
"telemt_user_msgs_to_client{{user=\"{}\"}} {}",
user,
s.msgs_to_client.load(std::sync::atomic::Ordering::Relaxed)
);
} }
let ip_stats = ip_tracker.get_stats().await; let ip_stats = ip_tracker.get_stats().await;
@ -2043,25 +1858,16 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
.get_recent_counts_for_users(&unique_users_vec) .get_recent_counts_for_users(&unique_users_vec)
.await; .await;
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs");
out,
"# HELP telemt_user_unique_ips_current Per-user current number of unique active IPs"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_current gauge");
let _ = writeln!( let _ = writeln!(
out, out,
"# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window" "# HELP telemt_user_unique_ips_recent_window Per-user unique IPs seen in configured observation window"
); );
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_recent_window gauge");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)");
out,
"# HELP telemt_user_unique_ips_limit Effective per-user unique IP limit (0 means unlimited)"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_limit gauge");
let _ = writeln!( let _ = writeln!(out, "# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)");
out,
"# HELP telemt_user_unique_ips_utilization Per-user unique IP usage ratio (0 for unlimited)"
);
let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge"); let _ = writeln!(out, "# TYPE telemt_user_unique_ips_utilization gauge");
for user in unique_users { for user in unique_users {
@ -2072,34 +1878,29 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
.get(&user) .get(&user)
.copied() .copied()
.filter(|limit| *limit > 0) .filter(|limit| *limit > 0)
.or((config.access.user_max_unique_ips_global_each > 0) .or(
.then_some(config.access.user_max_unique_ips_global_each)) (config.access.user_max_unique_ips_global_each > 0)
.then_some(config.access.user_max_unique_ips_global_each),
)
.unwrap_or(0); .unwrap_or(0);
let utilization = if limit > 0 { let utilization = if limit > 0 {
current as f64 / limit as f64 current as f64 / limit as f64
} else { } else {
0.0 0.0
}; };
let _ = writeln!( let _ = writeln!(out, "telemt_user_unique_ips_current{{user=\"{}\"}} {}", user, current);
out,
"telemt_user_unique_ips_current{{user=\"{}\"}} {}",
user, current
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}", "telemt_user_unique_ips_recent_window{{user=\"{}\"}} {}",
user, user,
recent_counts.get(&user).copied().unwrap_or(0) recent_counts.get(&user).copied().unwrap_or(0)
); );
let _ = writeln!( let _ = writeln!(out, "telemt_user_unique_ips_limit{{user=\"{}\"}} {}", user, limit);
out,
"telemt_user_unique_ips_limit{{user=\"{}\"}} {}",
user, limit
);
let _ = writeln!( let _ = writeln!(
out, out,
"telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}", "telemt_user_unique_ips_utilization{{user=\"{}\"}} {:.6}",
user, utilization user,
utilization
); );
} }
} }
@ -2110,8 +1911,8 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use http_body_util::BodyExt;
use std::net::IpAddr; use std::net::IpAddr;
use http_body_util::BodyExt;
#[tokio::test] #[tokio::test]
async fn test_render_metrics_format() { async fn test_render_metrics_format() {
@ -2166,10 +1967,13 @@ mod tests {
assert!(output.contains("telemt_upstream_connect_success_total 1")); assert!(output.contains("telemt_upstream_connect_success_total 1"));
assert!(output.contains("telemt_upstream_connect_fail_total 1")); assert!(output.contains("telemt_upstream_connect_fail_total 1"));
assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1")); assert!(output.contains("telemt_upstream_connect_failfast_hard_error_total 1"));
assert!(output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1"));
assert!( assert!(
output output.contains("telemt_upstream_connect_attempts_per_request{bucket=\"2\"} 1")
.contains("telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1") );
assert!(
output.contains(
"telemt_upstream_connect_duration_success_total{bucket=\"101_500ms\"} 1"
)
); );
assert!( assert!(
output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1") output.contains("telemt_upstream_connect_duration_fail_total{bucket=\"gt_1000ms\"} 1")
@ -2246,10 +2050,9 @@ mod tests {
assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter")); assert!(output.contains("# TYPE telemt_relay_pressure_evict_total counter"));
assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter"));
assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter"));
assert!( assert!(output.contains(
output "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge"
.contains("# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge") ));
);
assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_current gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge"));
assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge")); assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge"));
@ -2266,17 +2069,14 @@ mod tests {
stats.increment_connects_all(); stats.increment_connects_all();
stats.increment_connects_all(); stats.increment_connects_all();
let req = Request::builder().uri("/metrics").body(()).unwrap(); let req = Request::builder()
let resp = handle(req, &stats, &beobachten, &tracker, &config) .uri("/metrics")
.await .body(())
.unwrap(); .unwrap();
let resp = handle(req, &stats, &beobachten, &tracker, &config).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let body = resp.into_body().collect().await.unwrap().to_bytes(); let body = resp.into_body().collect().await.unwrap().to_bytes();
assert!( assert!(std::str::from_utf8(body.as_ref()).unwrap().contains("telemt_connections_total 3"));
std::str::from_utf8(body.as_ref())
.unwrap()
.contains("telemt_connections_total 3")
);
config.general.beobachten = true; config.general.beobachten = true;
config.general.beobachten_minutes = 10; config.general.beobachten_minutes = 10;
@ -2285,7 +2085,10 @@ mod tests {
"203.0.113.10".parse::<IpAddr>().unwrap(), "203.0.113.10".parse::<IpAddr>().unwrap(),
Duration::from_secs(600), Duration::from_secs(600),
); );
let req_beob = Request::builder().uri("/beobachten").body(()).unwrap(); let req_beob = Request::builder()
.uri("/beobachten")
.body(())
.unwrap();
let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config) let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config)
.await .await
.unwrap(); .unwrap();
@ -2295,7 +2098,10 @@ mod tests {
assert!(beob_text.contains("[TLS-scanner]")); assert!(beob_text.contains("[TLS-scanner]"));
assert!(beob_text.contains("203.0.113.10-1")); assert!(beob_text.contains("203.0.113.10-1"));
let req404 = Request::builder().uri("/other").body(()).unwrap(); let req404 = Request::builder()
.uri("/other")
.body(())
.unwrap();
let resp404 = handle(req404, &stats, &beobachten, &tracker, &config) let resp404 = handle(req404, &stats, &beobachten, &tracker, &config)
.await .await
.unwrap(); .unwrap();

View File

@ -26,7 +26,9 @@ fn parse_ip_spec(ip_spec: &str) -> Result<IpAddr> {
} }
let ip = ip_spec.parse::<IpAddr>().map_err(|_| { let ip = ip_spec.parse::<IpAddr>().map_err(|_| {
ProxyError::Config(format!("network.dns_overrides IP is invalid: '{ip_spec}'")) ProxyError::Config(format!(
"network.dns_overrides IP is invalid: '{ip_spec}'"
))
})?; })?;
if matches!(ip, IpAddr::V6(_)) { if matches!(ip, IpAddr::V6(_)) {
return Err(ProxyError::Config(format!( return Err(ProxyError::Config(format!(
@ -101,9 +103,9 @@ pub fn validate_entries(entries: &[String]) -> Result<()> {
/// Replace runtime DNS overrides with a new validated snapshot. /// Replace runtime DNS overrides with a new validated snapshot.
pub fn install_entries(entries: &[String]) -> Result<()> { pub fn install_entries(entries: &[String]) -> Result<()> {
let parsed = parse_entries(entries)?; let parsed = parse_entries(entries)?;
let mut guard = overrides_store().write().map_err(|_| { let mut guard = overrides_store()
ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()) .write()
})?; .map_err(|_| ProxyError::Config("network.dns_overrides runtime lock is poisoned".to_string()))?;
*guard = parsed; *guard = parsed;
Ok(()) Ok(())
} }

View File

@ -1,5 +1,4 @@
#![allow(dead_code)] #![allow(dead_code)]
#![allow(clippy::items_after_test_module)]
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
@ -11,9 +10,7 @@ use tracing::{debug, info, warn};
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType}; use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
use crate::error::Result; use crate::error::Result;
use crate::network::stun::{ use crate::network::stun::{stun_probe_family_with_bind, DualStunResult, IpFamily, StunProbeResult};
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -81,8 +78,13 @@ pub async fn run_probe(
warn!("STUN probe is enabled but network.stun_servers is empty"); warn!("STUN probe is enabled but network.stun_servers is empty");
DualStunResult::default() DualStunResult::default()
} else { } else {
probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None) probe_stun_servers_parallel(
.await &servers,
stun_nat_probe_concurrency.max(1),
None,
None,
)
.await
} }
} else if nat_probe { } else if nat_probe {
info!("STUN probe is disabled by network.stun_use=false"); info!("STUN probe is disabled by network.stun_use=false");
@ -97,8 +99,7 @@ pub async fn run_probe(
let UpstreamType::Direct { let UpstreamType::Direct {
interface, interface,
bind_addresses, bind_addresses,
} = &upstream.upstream_type } = &upstream.upstream_type else {
else {
continue; continue;
}; };
if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) { if let Some(addrs) = bind_addresses.as_ref().filter(|v| !v.is_empty()) {
@ -198,10 +199,11 @@ pub async fn run_probe(
if nat_probe if nat_probe
&& probe.reflected_ipv4.is_none() && probe.reflected_ipv4.is_none()
&& probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false) && probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false)
&& let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await
{ {
probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0)); if let Some(public_ip) = detect_public_ipv4_http(&config.http_ip_detect_urls).await {
info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback"); probe.reflected_ipv4 = Some(SocketAddr::new(IpAddr::V4(public_ip), 0));
info!(public_ip = %public_ip, "STUN unavailable, using HTTP public IPv4 fallback");
}
} }
probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) { probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) {
@ -215,20 +217,12 @@ pub async fn run_probe(
probe.ipv4_usable = config.ipv4 probe.ipv4_usable = config.ipv4
&& probe.detected_ipv4.is_some() && probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon && (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false));
|| probe
.reflected_ipv4
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()); let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
probe.ipv6_usable = ipv6_enabled probe.ipv6_usable = ipv6_enabled
&& probe.detected_ipv6.is_some() && probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon && (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false));
|| probe
.reflected_ipv6
.map(|r| !is_bogon(r.ip()))
.unwrap_or(false));
Ok(probe) Ok(probe)
} }
@ -286,6 +280,8 @@ async fn probe_stun_servers_parallel(
while next_idx < servers.len() && join_set.len() < concurrency { while next_idx < servers.len() && join_set.len() < concurrency {
let stun_addr = servers[next_idx].clone(); let stun_addr = servers[next_idx].clone();
next_idx += 1; next_idx += 1;
let bind_v4 = bind_v4;
let bind_v6 = bind_v6;
join_set.spawn(async move { join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, async { let res = timeout(STUN_BATCH_TIMEOUT, async {
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?; let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
@ -304,15 +300,11 @@ async fn probe_stun_servers_parallel(
match task { match task {
Ok((stun_addr, Ok(Ok(result)))) => { Ok((stun_addr, Ok(Ok(result)))) => {
if let Some(v4) = result.v4 { if let Some(v4) = result.v4 {
let entry = best_v4_by_ip let entry = best_v4_by_ip.entry(v4.reflected_addr.ip()).or_insert((0, v4));
.entry(v4.reflected_addr.ip())
.or_insert((0, v4));
entry.0 += 1; entry.0 += 1;
} }
if let Some(v6) = result.v6 { if let Some(v6) = result.v6 {
let entry = best_v6_by_ip let entry = best_v6_by_ip.entry(v6.reflected_addr.ip()).or_insert((0, v6));
.entry(v6.reflected_addr.ip())
.or_insert((0, v6));
entry.0 += 1; entry.0 += 1;
} }
if result.v4.is_some() || result.v6.is_some() { if result.v4.is_some() || result.v6.is_some() {
@ -332,11 +324,17 @@ async fn probe_stun_servers_parallel(
} }
let mut out = DualStunResult::default(); let mut out = DualStunResult::default();
if let Some((_, best)) = best_v4_by_ip.into_values().max_by_key(|(count, _)| *count) { if let Some((_, best)) = best_v4_by_ip
.into_values()
.max_by_key(|(count, _)| *count)
{
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v4 = Some(best); out.v4 = Some(best);
} }
if let Some((_, best)) = best_v6_by_ip.into_values().max_by_key(|(count, _)| *count) { if let Some((_, best)) = best_v6_by_ip
.into_values()
.max_by_key(|(count, _)| *count)
{
info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip()); info!("STUN-Quorum reached, IP: {}", best.reflected_addr.ip());
out.v6 = Some(best); out.v6 = Some(best);
} }
@ -349,8 +347,7 @@ pub fn decide_network_capabilities(
middle_proxy_nat_ip: Option<IpAddr>, middle_proxy_nat_ip: Option<IpAddr>,
) -> NetworkDecision { ) -> NetworkDecision {
let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some(); let ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some();
let ipv6_dc = let ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_))); let nat_ip_v4 = matches!(middle_proxy_nat_ip, Some(IpAddr::V4(_)));
let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_))); let nat_ip_v6 = matches!(middle_proxy_nat_ip, Some(IpAddr::V6(_)));
@ -537,26 +534,10 @@ pub fn is_bogon_v6(ip: Ipv6Addr) -> bool {
pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) { pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) {
info!( info!(
ipv4 = probe ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
.detected_ipv4 ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
.as_ref() reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
.map(|v| v.to_string()) reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
.unwrap_or_else(|| "-".into()),
ipv6 = probe
.detected_ipv6
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "-".into()),
reflected_v4 = probe
.reflected_ipv4
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
reflected_v6 = probe
.reflected_ipv6
.as_ref()
.map(|v| v.ip().to_string())
.unwrap_or_else(|| "-".into()),
ipv4_bogon = probe.ipv4_is_bogon, ipv4_bogon = probe.ipv4_is_bogon,
ipv6_bogon = probe.ipv6_is_bogon, ipv6_bogon = probe.ipv6_is_bogon,
ipv4_me = decision.ipv4_me, ipv4_me = decision.ipv4_me,

View File

@ -4,8 +4,8 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::OnceLock; use std::sync::OnceLock;
use tokio::net::{UdpSocket, lookup_host}; use tokio::net::{lookup_host, UdpSocket};
use tokio::time::{Duration, sleep, timeout}; use tokio::time::{timeout, Duration, sleep};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@ -41,13 +41,13 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
stun_probe_family(stun_addr, IpFamily::V6), stun_probe_family(stun_addr, IpFamily::V6),
); );
Ok(DualStunResult { v4: v4?, v6: v6? }) Ok(DualStunResult {
v4: v4?,
v6: v6?,
})
} }
pub async fn stun_probe_family( pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> {
stun_addr: &str,
family: IpFamily,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await stun_probe_family_with_bind(stun_addr, family, None).await
} }
@ -76,18 +76,13 @@ pub async fn stun_probe_family_with_bind(
if let Some(addr) = target_addr { if let Some(addr) = target_addr {
match socket.connect(addr).await { match socket.connect(addr).await {
Ok(()) => {} Ok(()) => {}
Err(e) Err(e) if family == IpFamily::V6 && matches!(
if family == IpFamily::V6 e.kind(),
&& matches!( std::io::ErrorKind::NetworkUnreachable
e.kind(), | std::io::ErrorKind::HostUnreachable
std::io::ErrorKind::NetworkUnreachable | std::io::ErrorKind::Unsupported
| std::io::ErrorKind::HostUnreachable | std::io::ErrorKind::NetworkDown
| std::io::ErrorKind::Unsupported ) => return Ok(None),
| std::io::ErrorKind::NetworkDown
) =>
{
return Ok(None);
}
Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))), Err(e) => return Err(ProxyError::Proxy(format!("STUN connect failed: {e}"))),
} }
} else { } else {
@ -130,16 +125,16 @@ pub async fn stun_probe_family_with_bind(
let magic = 0x2112A442u32.to_be_bytes(); let magic = 0x2112A442u32.to_be_bytes();
let txid = &req[8..20]; let txid = &req[8..20];
let mut idx = 20; let mut idx = 20;
while idx + 4 <= n { while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4; idx += 4;
if idx + alen > n { if idx + alen > n {
break; break;
} }
match atype { match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 { if alen < 8 {
break; break;
@ -208,8 +203,9 @@ pub async fn stun_probe_family_with_bind(
_ => {} _ => {}
} }
idx += (alen + 3) & !3; idx += (alen + 3) & !3;
} }
} }
Ok(None) Ok(None)
@ -237,11 +233,7 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
.await .await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?; .map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;
let target = addrs.find(|a| { let target = addrs
matches!( .find(|a| matches!((a.is_ipv4(), family), (true, IpFamily::V4) | (false, IpFamily::V6)));
(a.is_ipv4(), family),
(true, IpFamily::V4) | (false, IpFamily::V6)
)
});
Ok(target) Ok(target)
} }

View File

@ -33,89 +33,35 @@ pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
// ============= Middle Proxies (for advertising) ============= // ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert( m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
1, m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)], m.insert(2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
); m.insert(-2, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)]);
m.insert( m.insert(3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
-1, m.insert(-3, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)]);
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)],
);
m.insert(
2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
-2,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888)],
);
m.insert(
3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888)],
);
m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]); m.insert(4, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888)]);
m.insert( m.insert(-4, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)]);
-4,
vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 165, 109)), 8888)],
);
m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]); m.insert(5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
m.insert( m.insert(-5, vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)]);
-5,
vec![(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888)],
);
m m
}); });
pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
LazyLock::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert( m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
1, m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], m.insert(2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
); m.insert(-2, vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)]);
m.insert( m.insert(3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
-1, m.insert(-3, vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)]);
vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)], m.insert(4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
); m.insert(-4, vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)]);
m.insert( m.insert(5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
2, m.insert(-5, vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)]);
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
-2,
vec![(IpAddr::V6("2001:67c:04e8:f002::d".parse().unwrap()), 80)],
);
m.insert(
3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
-3,
vec![(IpAddr::V6("2001:b28:f23d:f003::d".parse().unwrap()), 8888)],
);
m.insert(
4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
-4,
vec![(IpAddr::V6("2001:67c:04e8:f004::d".parse().unwrap()), 8888)],
);
m.insert(
5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m.insert(
-5,
vec![(IpAddr::V6("2001:b28:f23f:f005::d".parse().unwrap()), 8888)],
);
m m
}); });
@ -143,12 +89,12 @@ impl ProtoTag {
_ => None, _ => None,
} }
} }
/// Convert to 4 bytes (little-endian) /// Convert to 4 bytes (little-endian)
pub fn to_bytes(self) -> [u8; 4] { pub fn to_bytes(self) -> [u8; 4] {
(self as u32).to_le_bytes() (self as u32).to_le_bytes()
} }
/// Get protocol tag as bytes slice /// Get protocol tag as bytes slice
pub fn as_bytes(&self) -> &'static [u8; 4] { pub fn as_bytes(&self) -> &'static [u8; 4] {
match self { match self {
@ -276,7 +222,9 @@ pub const SMALL_BUFFER_SIZE: usize = 8192;
// ============= Statistics ============= // ============= Statistics =============
/// Duration buckets for histogram metrics /// Duration buckets for histogram metrics
pub static DURATION_BUCKETS: &[f64] = &[0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0]; pub static DURATION_BUCKETS: &[f64] = &[
0.1, 0.5, 1.0, 2.0, 5.0, 15.0, 60.0, 300.0, 600.0, 1800.0,
];
// ============= Reserved Nonce Patterns ============= // ============= Reserved Nonce Patterns =============
@ -287,27 +235,29 @@ pub static RESERVED_NONCE_FIRST_BYTES: &[u8] = &[0xef];
pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[ pub static RESERVED_NONCE_BEGINNINGS: &[[u8; 4]] = &[
[0x48, 0x45, 0x41, 0x44], // HEAD [0x48, 0x45, 0x41, 0x44], // HEAD
[0x50, 0x4F, 0x53, 0x54], // POST [0x50, 0x4F, 0x53, 0x54], // POST
[0x47, 0x45, 0x54, 0x20], // GET [0x47, 0x45, 0x54, 0x20], // GET
[0xee, 0xee, 0xee, 0xee], // Intermediate [0xee, 0xee, 0xee, 0xee], // Intermediate
[0xdd, 0xdd, 0xdd, 0xdd], // Secure [0xdd, 0xdd, 0xdd, 0xdd], // Secure
[0x16, 0x03, 0x01, 0x02], // TLS [0x16, 0x03, 0x01, 0x02], // TLS
]; ];
/// Reserved continuation bytes (bytes 4-7) /// Reserved continuation bytes (bytes 4-7)
pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[[0x00, 0x00, 0x00, 0x00]]; pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
[0x00, 0x00, 0x00, 0x00],
];
// ============= RPC Constants (for Middle Proxy) ============= // ============= RPC Constants (for Middle Proxy) =============
/// RPC Proxy Request /// RPC Proxy Request
/// RPC Flags (from Erlang mtp_rpc.erl) /// RPC Flags (from Erlang mtp_rpc.erl)
pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2; pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8;
pub const RPC_FLAG_MAGIC: u32 = 0x1000; pub const RPC_FLAG_MAGIC: u32 = 0x1000;
pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; pub const RPC_FLAG_EXTMODE2: u32 = 0x20000;
pub const RPC_FLAG_PAD: u32 = 0x8000000; pub const RPC_FLAG_PAD: u32 = 0x8000000;
pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000;
pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; pub const RPC_FLAG_QUICKACK: u32 = 0x80000000;
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
/// RPC Proxy Answer /// RPC Proxy Answer
@ -335,66 +285,67 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
// ============= Middle-End Proxy Servers =============
pub const ME_PROXY_PORT: u16 = 8888;
pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| { // ============= Middle-End Proxy Servers =============
vec![ pub const ME_PROXY_PORT: u16 = 8888;
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), vec![
(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888),
] (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888),
}); (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888),
]
});
// ============= RPC Constants (u32 native endian) =============
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c
pub const RPC_NONCE_U32: u32 = 0x7acb87aa;
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5;
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda;
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
// mtproto-common.h
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
pub const RPC_PING_U32: u32 = 0x5730a2df;
pub const RPC_PONG_U32: u32 = 0x8430eaa7;
pub const RPC_CRYPTO_NONE_U32: u32 = 0;
pub const RPC_CRYPTO_AES_U32: u32 = 1;
pub mod proxy_flags {
pub const FLAG_HAS_AD_TAG: u32 = 1;
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
pub const FLAG_MAGIC: u32 = 0x1000;
pub const FLAG_EXTMODE2: u32 = 0x20000;
pub const FLAG_PAD: u32 = 0x8000000;
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
// ============= RPC Constants (u32 native endian) ============= pub mod rpc_crypto_flags {
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c pub const USE_CRC32C: u32 = 0x800;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
pub const RPC_NONCE_U32: u32 = 0x7acb87aa; #[cfg(test)]
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; #[path = "tests/tls_size_constants_security_tests.rs"]
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; mod tls_size_constants_security_tests;
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
#[cfg(test)]
// mtproto-common.h
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
pub const RPC_PING_U32: u32 = 0x5730a2df;
pub const RPC_PONG_U32: u32 = 0x8430eaa7;
pub const RPC_CRYPTO_NONE_U32: u32 = 0;
pub const RPC_CRYPTO_AES_U32: u32 = 1;
pub mod proxy_flags {
pub const FLAG_HAS_AD_TAG: u32 = 1;
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
pub const FLAG_MAGIC: u32 = 0x1000;
pub const FLAG_EXTMODE2: u32 = 0x20000;
pub const FLAG_PAD: u32 = 0x8000000;
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
pub mod rpc_crypto_flags {
pub const USE_CRC32C: u32 = 0x800;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
#[cfg(test)]
#[path = "tests/tls_size_constants_security_tests.rs"]
mod tls_size_constants_security_tests;
#[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_proto_tag_roundtrip() { fn test_proto_tag_roundtrip() {
for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { for tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
@ -403,20 +354,20 @@ mod tests {
assert_eq!(tag, parsed); assert_eq!(tag, parsed);
} }
} }
#[test] #[test]
fn test_proto_tag_values() { fn test_proto_tag_values() {
assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED); assert_eq!(ProtoTag::Abridged.to_bytes(), PROTO_TAG_ABRIDGED);
assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE); assert_eq!(ProtoTag::Intermediate.to_bytes(), PROTO_TAG_INTERMEDIATE);
assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE); assert_eq!(ProtoTag::Secure.to_bytes(), PROTO_TAG_SECURE);
} }
#[test] #[test]
fn test_invalid_proto_tag() { fn test_invalid_proto_tag() {
assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none()); assert!(ProtoTag::from_bytes([0, 0, 0, 0]).is_none());
assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none()); assert!(ProtoTag::from_bytes([0xff, 0xff, 0xff, 0xff]).is_none());
} }
#[test] #[test]
fn test_datacenters_count() { fn test_datacenters_count() {
assert_eq!(TG_DATACENTERS_V4.len(), 5); assert_eq!(TG_DATACENTERS_V4.len(), 5);

View File

@ -22,7 +22,7 @@ impl FrameExtra {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
/// Create with quickack flag set /// Create with quickack flag set
pub fn with_quickack() -> Self { pub fn with_quickack() -> Self {
Self { Self {
@ -30,7 +30,7 @@ impl FrameExtra {
..Default::default() ..Default::default()
} }
} }
/// Create with simple_ack flag set /// Create with simple_ack flag set
pub fn with_simple_ack() -> Self { pub fn with_simple_ack() -> Self {
Self { Self {
@ -38,7 +38,7 @@ impl FrameExtra {
..Default::default() ..Default::default()
} }
} }
/// Check if any flags are set /// Check if any flags are set
pub fn has_flags(&self) -> bool { pub fn has_flags(&self) -> bool {
self.quickack || self.simple_ack || self.skip_send self.quickack || self.simple_ack || self.skip_send
@ -76,22 +76,22 @@ impl FrameMode {
FrameMode::Abridged => 4, FrameMode::Abridged => 4,
FrameMode::Intermediate => 4, FrameMode::Intermediate => 4,
FrameMode::SecureIntermediate => 4 + 3, // length + padding FrameMode::SecureIntermediate => 4 + 3, // length + padding
FrameMode::Full => 12 + 16, // header + max CBC padding FrameMode::Full => 12 + 16, // header + max CBC padding
} }
} }
} }
/// Validate message length for MTProto /// Validate message length for MTProto
pub fn validate_message_length(len: usize) -> bool { pub fn validate_message_length(len: usize) -> bool {
use super::constants::{MAX_MSG_LEN, MIN_MSG_LEN, PADDING_FILLER}; use super::constants::{MIN_MSG_LEN, MAX_MSG_LEN, PADDING_FILLER};
(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len()) (MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) && len.is_multiple_of(PADDING_FILLER.len())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_frame_extra_default() { fn test_frame_extra_default() {
let extra = FrameExtra::default(); let extra = FrameExtra::default();
@ -100,18 +100,18 @@ mod tests {
assert!(!extra.skip_send); assert!(!extra.skip_send);
assert!(!extra.has_flags()); assert!(!extra.has_flags());
} }
#[test] #[test]
fn test_frame_extra_flags() { fn test_frame_extra_flags() {
let extra = FrameExtra::with_quickack(); let extra = FrameExtra::with_quickack();
assert!(extra.quickack); assert!(extra.quickack);
assert!(extra.has_flags()); assert!(extra.has_flags());
let extra = FrameExtra::with_simple_ack(); let extra = FrameExtra::with_simple_ack();
assert!(extra.simple_ack); assert!(extra.simple_ack);
assert!(extra.has_flags()); assert!(extra.has_flags());
} }
#[test] #[test]
fn test_validate_message_length() { fn test_validate_message_length() {
assert!(validate_message_length(12)); // MIN_MSG_LEN assert!(validate_message_length(12)); // MIN_MSG_LEN
@ -119,4 +119,4 @@ mod tests {
assert!(!validate_message_length(8)); // Too small assert!(!validate_message_length(8)); // Too small
assert!(!validate_message_length(13)); // Not aligned to 4 assert!(!validate_message_length(13)); // Not aligned to 4
} }
} }

View File

@ -12,4 +12,4 @@ pub use frame::*;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use obfuscation::*; pub use obfuscation::*;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use tls::*; pub use tls::*;

View File

@ -2,9 +2,9 @@
#![allow(dead_code)] #![allow(dead_code)]
use super::constants::*;
use crate::crypto::{AesCtr, sha256};
use zeroize::Zeroize; use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr};
use super::constants::*;
/// Obfuscation parameters from handshake /// Obfuscation parameters from handshake
/// ///
@ -44,40 +44,41 @@ impl ObfuscationParams {
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets { for (username, secret) in secrets {
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret); dec_key_input.extend_from_slice(secret);
let decrypt_key = sha256(&dec_key_input); let decrypt_key = sha256(&dec_key_input);
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => continue, None => continue,
}; };
let dc_idx = let dc_idx = i16::from_le_bytes(
i16::from_le_bytes(decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()); decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret); enc_key_input.extend_from_slice(secret);
let encrypt_key = sha256(&enc_key_input); let encrypt_key = sha256(&enc_key_input);
let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); let encrypt_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
return Some(( return Some((
ObfuscationParams { ObfuscationParams {
decrypt_key, decrypt_key,
@ -90,20 +91,20 @@ impl ObfuscationParams {
username.clone(), username.clone(),
)); ));
} }
None None
} }
/// Create AES-CTR decryptor for client -> proxy direction /// Create AES-CTR decryptor for client -> proxy direction
pub fn create_decryptor(&self) -> AesCtr { pub fn create_decryptor(&self) -> AesCtr {
AesCtr::new(&self.decrypt_key, self.decrypt_iv) AesCtr::new(&self.decrypt_key, self.decrypt_iv)
} }
/// Create AES-CTR encryptor for proxy -> client direction /// Create AES-CTR encryptor for proxy -> client direction
pub fn create_encryptor(&self) -> AesCtr { pub fn create_encryptor(&self) -> AesCtr {
AesCtr::new(&self.encrypt_key, self.encrypt_iv) AesCtr::new(&self.encrypt_key, self.encrypt_iv)
} }
/// Get the combined encrypt key and IV for fast mode /// Get the combined encrypt key and IV for fast mode
pub fn enc_key_iv(&self) -> Vec<u8> { pub fn enc_key_iv(&self) -> Vec<u8> {
let mut result = Vec::with_capacity(KEY_LEN + IV_LEN); let mut result = Vec::with_capacity(KEY_LEN + IV_LEN);
@ -119,7 +120,7 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
let nonce_vec = random_bytes(HANDSHAKE_LEN); let nonce_vec = random_bytes(HANDSHAKE_LEN);
let mut nonce = [0u8; HANDSHAKE_LEN]; let mut nonce = [0u8; HANDSHAKE_LEN];
nonce.copy_from_slice(&nonce_vec); nonce.copy_from_slice(&nonce_vec);
if is_valid_nonce(&nonce) { if is_valid_nonce(&nonce) {
return nonce; return nonce;
} }
@ -131,17 +132,17 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false; return false;
} }
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false; return false;
} }
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false; return false;
} }
true true
} }
@ -152,7 +153,7 @@ pub fn prepare_tg_nonce(
enc_key_iv: Option<&[u8]>, enc_key_iv: Option<&[u8]>,
) { ) {
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
if let Some(key_iv) = enc_key_iv { if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect(); let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
@ -170,39 +171,39 @@ pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv); let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv); let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let mut result = nonce.to_vec(); let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
result result
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_is_valid_nonce() { fn test_is_valid_nonce() {
let mut valid = [0x42u8; HANDSHAKE_LEN]; let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]); valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid)); assert!(is_valid_nonce(&valid));
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef; invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD"); invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
let mut invalid = [0x42u8; HANDSHAKE_LEN]; let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
} }
#[test] #[test]
fn test_generate_nonce() { fn test_generate_nonce() {
let mut counter = 0u8; let mut counter = 0u8;
@ -210,7 +211,7 @@ mod tests {
counter = counter.wrapping_add(1); counter = counter.wrapping_add(1);
vec![counter; n] vec![counter; n]
}); });
assert!(is_valid_nonce(&nonce)); assert!(is_valid_nonce(&nonce));
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
} }

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use crate::crypto::sha256_hmac;
use std::time::Instant; use std::time::Instant;
use crate::crypto::sha256_hmac;
/// Helper to create a byte vector of specific length. /// Helper to create a byte vector of specific length.
fn make_garbage(len: usize) -> Vec<u8> { fn make_garbage(len: usize) -> Vec<u8> {
@ -33,7 +33,8 @@ fn make_valid_tls_handshake_with_session_id(
let digest = make_digest(secret, &handshake, timestamp); let digest = make_digest(secret, &handshake, timestamp);
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake handshake
} }
@ -95,15 +96,15 @@ fn extract_sni_with_overlapping_extension_lengths_rejected() {
h.push(0); // Session ID length: 0 h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start // Extensions start
h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32 h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32
// Extension 1: SNI (type 0) // Extension 1: SNI (type 0)
h.extend_from_slice(&[0x00, 0x00]); h.extend_from_slice(&[0x00, 0x00]);
h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32) h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32)
h.extend_from_slice(&[0u8; 64]); h.extend_from_slice(&[0u8; 64]);
assert!(extract_sni_from_client_hello(&h).is_none()); assert!(extract_sni_from_client_hello(&h).is_none());
} }
@ -117,19 +118,19 @@ fn extract_sni_with_infinite_loop_potential_extension_rejected() {
h.push(0); // Session ID length: 0 h.push(0); // Session ID length: 0
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites
h.extend_from_slice(&[0x01, 0x00]); // Compression h.extend_from_slice(&[0x01, 0x00]); // Compression
// Extensions start // Extensions start
h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16 h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16
// Extension: zero length but claims more? // Extension: zero length but claims more?
// If our parser didn't advance, it might loop. // If our parser didn't advance, it might loop.
// Telemt uses `pos += 4 + elen;` so it always advances. // Telemt uses `pos += 4 + elen;` so it always advances.
h.extend_from_slice(&[0x12, 0x34]); // Unknown type h.extend_from_slice(&[0x12, 0x34]); // Unknown type
h.extend_from_slice(&[0x00, 0x00]); // Length 0 h.extend_from_slice(&[0x00, 0x00]); // Length 0
// Fill the rest with garbage // Fill the rest with garbage
h.extend_from_slice(&[0x42; 12]); h.extend_from_slice(&[0x42; 12]);
// We expect it to finish without SNI found // We expect it to finish without SNI found
assert!(extract_sni_from_client_hello(&h).is_none()); assert!(extract_sni_from_client_hello(&h).is_none());
} }
@ -142,7 +143,7 @@ fn extract_sni_with_invalid_hostname_rejected() {
sni.push(0); sni.push(0);
sni.extend_from_slice(&(host.len() as u16).to_be_bytes()); sni.extend_from_slice(&(host.len() as u16).to_be_bytes());
sni.extend_from_slice(host); sni.extend_from_slice(host);
let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header
h.push(0x01); // ClientHello h.push(0x01); // ClientHello
h.extend_from_slice(&[0x00, 0x00, 0x5C]); h.extend_from_slice(&[0x00, 0x00, 0x5C]);
@ -151,19 +152,16 @@ fn extract_sni_with_invalid_hostname_rejected() {
h.push(0); h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]);
h.extend_from_slice(&[0x01, 0x00]); h.extend_from_slice(&[0x01, 0x00]);
let mut ext = Vec::new(); let mut ext = Vec::new();
ext.extend_from_slice(&0x0000u16.to_be_bytes()); ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni.len() as u16).to_be_bytes()); ext.extend_from_slice(&(sni.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni); ext.extend_from_slice(&sni);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext); h.extend_from_slice(&ext);
assert!( assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected");
extract_sni_from_client_hello(&h).is_none(),
"Invalid SNI hostname must be rejected"
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -235,7 +233,7 @@ fn is_tls_handshake_robustness_against_probing() {
assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
// Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer) // Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer)
assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); assert!(is_tls_handshake(&[0x16, 0x03, 0x03]));
// Invalid record type but matching version // Invalid record type but matching version
assert!(!is_tls_handshake(&[0x17, 0x03, 0x03])); assert!(!is_tls_handshake(&[0x17, 0x03, 0x03]));
// Plaintext HTTP request // Plaintext HTTP request
@ -249,12 +247,12 @@ fn validate_tls_handshake_at_time_strict_boundary() {
let secret = b"strict_boundary_secret_32_bytes_"; let secret = b"strict_boundary_secret_32_bytes_";
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let now: i64 = 1_000_000_000; let now: i64 = 1_000_000_000;
// Boundary: exactly TIME_SKEW_MAX (120s past) // Boundary: exactly TIME_SKEW_MAX (120s past)
let ts_past = (now - TIME_SKEW_MAX) as u32; let ts_past = (now - TIME_SKEW_MAX) as u32;
let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]); let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]);
assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some()); assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some());
// Boundary + 1s: should be rejected // Boundary + 1s: should be rejected
let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32;
let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]); let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]);
@ -270,14 +268,14 @@ fn extract_sni_with_duplicate_extensions_rejected() {
sni1.push(0); sni1.push(0);
sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes()); sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes());
sni1.extend_from_slice(host1); sni1.extend_from_slice(host1);
let host2 = b"second.com"; let host2 = b"second.com";
let mut sni2 = Vec::new(); let mut sni2 = Vec::new();
sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes()); sni2.extend_from_slice(&((host2.len() + 3) as u16).to_be_bytes());
sni2.push(0); sni2.push(0);
sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes()); sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes());
sni2.extend_from_slice(host2); sni2.extend_from_slice(host2);
let mut ext = Vec::new(); let mut ext = Vec::new();
// Ext 1: SNI // Ext 1: SNI
ext.extend_from_slice(&0x0000u16.to_be_bytes()); ext.extend_from_slice(&0x0000u16.to_be_bytes());
@ -287,7 +285,7 @@ fn extract_sni_with_duplicate_extensions_rejected() {
ext.extend_from_slice(&0x0000u16.to_be_bytes()); ext.extend_from_slice(&0x0000u16.to_be_bytes());
ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes());
ext.extend_from_slice(&sni2); ext.extend_from_slice(&sni2);
let mut body = Vec::new(); let mut body = Vec::new();
body.extend_from_slice(&[0x03, 0x03]); body.extend_from_slice(&[0x03, 0x03]);
body.extend_from_slice(&[0u8; 32]); body.extend_from_slice(&[0u8; 32]);
@ -308,7 +306,7 @@ fn extract_sni_with_duplicate_extensions_rejected() {
h.extend_from_slice(&[0x03, 0x03]); h.extend_from_slice(&[0x03, 0x03]);
h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); h.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
h.extend_from_slice(&handshake); h.extend_from_slice(&handshake);
// Duplicate SNI extensions are ambiguous and must fail closed. // Duplicate SNI extensions are ambiguous and must fail closed.
assert!(extract_sni_from_client_hello(&h).is_none()); assert!(extract_sni_from_client_hello(&h).is_none());
} }
@ -319,26 +317,21 @@ fn extract_alpn_with_malformed_list_rejected() {
alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5 alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5
alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5) alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5)
alpn_payload.extend_from_slice(b"h2"); alpn_payload.extend_from_slice(b"h2");
let mut ext = Vec::new(); let mut ext = Vec::new();
ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16) ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16)
ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes());
ext.extend_from_slice(&alpn_payload); ext.extend_from_slice(&alpn_payload);
let mut h = vec![ let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03];
0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03,
];
h.extend_from_slice(&[0u8; 32]); h.extend_from_slice(&[0u8; 32]);
h.push(0); h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); h.extend_from_slice(&(ext.len() as u16).to_be_bytes());
h.extend_from_slice(&ext); h.extend_from_slice(&ext);
let res = extract_alpn_from_client_hello(&h); let res = extract_alpn_from_client_hello(&h);
assert!( assert!(res.is_empty(), "Malformed ALPN list must return empty or fail");
res.is_empty(),
"Malformed ALPN list must return empty or fail"
);
} }
#[test] #[test]
@ -350,9 +343,9 @@ fn extract_sni_with_huge_extension_header_rejected() {
h.extend_from_slice(&[0u8; 32]); h.extend_from_slice(&[0u8; 32]);
h.push(0); h.push(0);
h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]);
// Extensions start // Extensions start
h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything) h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything)
assert!(extract_sni_from_client_hello(&h).is_none()); assert!(extract_sni_from_client_hello(&h).is_none());
} }

View File

@ -84,10 +84,7 @@ fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec<u
#[test] #[test]
fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]);
assert_eq!( assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com"));
extract_sni_from_client_hello(&valid).as_deref(),
Some("example.com")
);
assert_eq!( assert_eq!(
extract_alpn_from_client_hello(&valid), extract_alpn_from_client_hello(&valid),
vec![b"h2".to_vec(), b"http/1.1".to_vec()] vec![b"h2".to_vec(), b"http/1.1".to_vec()]
@ -124,14 +121,8 @@ fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() {
continue; continue;
} }
assert!( assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI");
extract_sni_from_client_hello(input).is_none(), assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN");
"corpus item {idx} must fail closed for SNI"
);
assert!(
extract_alpn_from_client_hello(input).is_empty(),
"corpus item {idx} must fail closed for ALPN"
);
} }
} }
@ -172,9 +163,7 @@ fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
for _ in 0..32 { for _ in 0..32 {
let mut mutated = base.clone(); let mut mutated = base.clone();
for _ in 0..2 { for _ in 0..2 {
seed = seed seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN);
mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1);
} }
@ -182,13 +171,9 @@ fn tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() {
} }
for (idx, handshake) in corpus.iter().enumerate() { for (idx, handshake) in corpus.iter().enumerate() {
let result = let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now));
assert!(result.is_ok(), "corpus item {idx} must not panic"); assert!(result.is_ok(), "corpus item {idx} must not panic");
assert!( assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed");
result.unwrap().is_none(),
"corpus item {idx} must fail closed"
);
} }
} }

View File

@ -1,37 +0,0 @@
use super::*;
#[test]
fn extension_builder_fails_closed_on_u16_length_overflow() {
let builder = TlsExtensionBuilder {
extensions: vec![0u8; (u16::MAX as usize) + 1],
};
let built = builder.build();
assert!(
built.is_empty(),
"oversized extension blob must fail closed instead of truncating length field"
);
}
#[test]
fn server_hello_builder_fails_closed_on_session_id_len_overflow() {
let builder = ServerHelloBuilder {
random: [0u8; 32],
session_id: vec![0xAB; (u8::MAX as usize) + 1],
cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256,
compression: 0,
extensions: TlsExtensionBuilder::new(),
};
let message = builder.build_message();
let record = builder.build_record();
assert!(
message.is_empty(),
"session_id length overflow must fail closed in message builder"
);
assert!(
record.is_empty(),
"session_id length overflow must fail closed in record builder"
);
}

View File

@ -1,9 +1,7 @@
use super::*; use super::*;
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::tls_front::emulator::build_emulated_server_hello; use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{ use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource};
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource,
};
use std::time::SystemTime; use std::time::SystemTime;
/// Build a TLS-handshake-like buffer that contains a valid HMAC digest /// Build a TLS-handshake-like buffer that contains a valid HMAC digest
@ -41,7 +39,8 @@ fn make_valid_tls_handshake_with_session_id(
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake handshake
} }
@ -181,10 +180,7 @@ fn second_user_in_list_found_when_first_does_not_match() {
("user_b".to_string(), secret_b.to_vec()), ("user_b".to_string(), secret_b.to_vec()),
]; ];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
assert!( assert!(result.is_some(), "user_b must be found even though user_a comes first");
result.is_some(),
"user_b must be found even though user_a comes first"
);
assert_eq!(result.unwrap().user, "user_b"); assert_eq!(result.unwrap().user, "user_b");
} }
@ -432,7 +428,8 @@ fn censor_probe_random_digests_all_rejected() {
let mut h = vec![0x42u8; min_len]; let mut h = vec![0x42u8; min_len];
h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
let rand_digest = rng.bytes(TLS_DIGEST_LEN); let rand_digest = rng.bytes(TLS_DIGEST_LEN);
h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&rand_digest); h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&rand_digest);
assert!( assert!(
validate_tls_handshake(&h, &secrets, true).is_none(), validate_tls_handshake(&h, &secrets, true).is_none(),
"Random digest at attempt {attempt} must not match" "Random digest at attempt {attempt} must not match"
@ -556,7 +553,8 @@ fn system_time_before_unix_epoch_is_rejected_without_panic() {
fn system_time_far_future_overflowing_i64_returns_none() { fn system_time_far_future_overflowing_i64_returns_none() {
// i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`.
let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1;
if let Some(far_future) = UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) if let Some(far_future) =
UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs))
{ {
assert!( assert!(
system_time_to_unix_secs(far_future).is_none(), system_time_to_unix_secs(far_future).is_none(),
@ -622,10 +620,7 @@ fn appended_trailing_byte_causes_rejection() {
let mut h = make_valid_tls_handshake(secret, 0); let mut h = make_valid_tls_handshake(secret, 0);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
assert!( assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline");
validate_tls_handshake(&h, &secrets, true).is_some(),
"baseline"
);
h.push(0x00); h.push(0x00);
assert!( assert!(
@ -652,7 +647,8 @@ fn zero_length_session_id_accepted() {
let computed = sha256_hmac(secret, &handshake); let computed = sha256_hmac(secret, &handshake);
// timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged.
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&computed); handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&computed);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
@ -777,18 +773,10 @@ fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() {
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0);
let cap_nonzero = validate_tls_handshake_at_time_with_boot_cap( let cap_nonzero =
&h, validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS);
&secrets,
true,
0,
BOOT_TIME_COMPAT_MAX_SECS,
);
assert!( assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC");
cap_zero.is_some(),
"ignore_time_skew=true must accept valid HMAC"
);
assert!( assert!(
cap_nonzero.is_some(), cap_nonzero.is_some(),
"ignore_time_skew path must not depend on boot-time cap" "ignore_time_skew path must not depend on boot-time cap"
@ -900,8 +888,8 @@ fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disa
let ts_i64 = now - offset; let ts_i64 = now - offset;
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix");
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); .is_some();
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset);
assert_eq!( assert_eq!(
accepted, expected, accepted, expected,
@ -929,8 +917,8 @@ fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() {
let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test");
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); .is_some();
assert!( assert!(
!accepted, !accepted,
"offset {offset} must be rejected outside strict skew window" "offset {offset} must be rejected outside strict skew window"
@ -952,8 +940,8 @@ fn stress_boot_disabled_validation_matches_time_diff_oracle() {
let ts = s as u32; let ts = s as u32;
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let accepted = let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0)
validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0).is_some(); .is_some();
let time_diff = now - i64::from(ts); let time_diff = now - i64::from(ts);
let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff);
assert_eq!( assert_eq!(
@ -972,10 +960,7 @@ fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() {
let mut secrets = Vec::new(); let mut secrets = Vec::new();
for i in 0..512u32 { for i in 0..512u32 {
secrets.push(( secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes()));
format!("noise-{i}"),
format!("noise-secret-{i}").into_bytes(),
));
} }
secrets.push(("target-user".to_string(), target_secret.to_vec())); secrets.push(("target-user".to_string(), target_secret.to_vec()));
@ -1033,10 +1018,7 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() {
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&h, &secrets, true); let result = validate_tls_handshake(&h, &secrets, true);
assert!( assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true");
result.is_some(),
"u32::MAX timestamp must be accepted with ignore_time_skew=true"
);
assert_eq!( assert_eq!(
result.unwrap().timestamp, result.unwrap().timestamp,
u32::MAX, u32::MAX,
@ -1168,17 +1150,16 @@ fn first_matching_user_wins_over_later_duplicate_secret() {
let secrets = vec![ let secrets = vec![
("decoy_1".to_string(), b"wrong_1".to_vec()), ("decoy_1".to_string(), b"wrong_1".to_vec()),
("winner".to_string(), shared.to_vec()), // first match ("winner".to_string(), shared.to_vec()), // first match
("decoy_2".to_string(), b"wrong_2".to_vec()), ("decoy_2".to_string(), b"wrong_2".to_vec()),
("loser".to_string(), shared.to_vec()), // second match — must not win ("loser".to_string(), shared.to_vec()), // second match — must not win
("decoy_3".to_string(), b"wrong_3".to_vec()), ("decoy_3".to_string(), b"wrong_3".to_vec()),
]; ];
let result = validate_tls_handshake(&h, &secrets, true); let result = validate_tls_handshake(&h, &secrets, true);
assert!(result.is_some()); assert!(result.is_some());
assert_eq!( assert_eq!(
result.unwrap().user, result.unwrap().user, "winner",
"winner",
"first matching user must be returned even when a later entry also matches" "first matching user must be returned even when a later entry also matches"
); );
} }
@ -1444,8 +1425,7 @@ fn test_build_server_hello_structure() {
assert!(response.len() > ccs_start + 6); assert!(response.len() > ccs_start + 6);
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
let ccs_len = let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
let app_start = ccs_start + ccs_len; let app_start = ccs_start + ccs_len;
assert!(response.len() > app_start + 5); assert!(response.len() > app_start + 5);
assert_eq!(response[app_start], TLS_RECORD_APPLICATION); assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
@ -1749,10 +1729,7 @@ fn empty_secret_hmac_is_supported() {
let handshake = make_valid_tls_handshake(secret, 0); let handshake = make_valid_tls_handshake(secret, 0);
let secrets = vec![("empty".to_string(), secret.to_vec())]; let secrets = vec![("empty".to_string(), secret.to_vec())];
let result = validate_tls_handshake(&handshake, &secrets, true); let result = validate_tls_handshake(&handshake, &secrets, true);
assert!( assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct");
result.is_some(),
"Empty HMAC key must not panic and must validate when correct"
);
} }
#[test] #[test]
@ -1825,10 +1802,7 @@ fn server_hello_application_data_payload_varies_across_runs() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec(); let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec();
assert!( assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler");
payload.iter().any(|&b| b != 0),
"Payload must not be all-zero deterministic filler"
);
unique_payloads.insert(payload); unique_payloads.insert(payload);
} }
@ -1872,13 +1846,7 @@ fn large_replay_window_does_not_expand_time_skew_acceptance() {
#[test] #[test]
fn parse_tls_record_header_accepts_tls_version_constant() { fn parse_tls_record_header_accepts_tls_version_constant() {
let header = [ let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A];
TLS_RECORD_HANDSHAKE,
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x2A,
];
let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted"); let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE); assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
assert_eq!(parsed.1, 42); assert_eq!(parsed.1, 42);
@ -1900,10 +1868,7 @@ fn server_hello_clamps_fake_cert_len_lower_bound() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!( assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes");
app_len, 64,
"fake cert payload must be clamped to minimum 64 bytes"
);
} }
#[test] #[test]
@ -1922,10 +1887,7 @@ fn server_hello_clamps_fake_cert_len_upper_bound() {
let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
assert_eq!(response[app_pos], TLS_RECORD_APPLICATION); assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
assert_eq!( assert_eq!(app_len, MAX_TLS_CIPHERTEXT_SIZE, "fake cert payload must be clamped to TLS record max bound");
app_len, MAX_TLS_CIPHERTEXT_SIZE,
"fake cert payload must be clamped to TLS record max bound"
);
} }
#[test] #[test]
@ -1936,15 +1898,7 @@ fn server_hello_new_session_ticket_count_matches_configuration() {
let rng = crate::crypto::SecureRandom::new(); let rng = crate::crypto::SecureRandom::new();
let tickets: u8 = 3; let tickets: u8 = 3;
let response = build_server_hello( let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets);
secret,
&client_digest,
&session_id,
1024,
&rng,
None,
tickets,
);
let mut pos = 0usize; let mut pos = 0usize;
let mut app_records = 0usize; let mut app_records = 0usize;
@ -1952,10 +1906,7 @@ fn server_hello_new_session_ticket_count_matches_configuration() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!( assert!(next <= response.len(), "TLS record must stay inside response bounds");
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
} }
@ -1976,15 +1927,7 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
let session_id = vec![0x54; 32]; let session_id = vec![0x54; 32];
let rng = crate::crypto::SecureRandom::new(); let rng = crate::crypto::SecureRandom::new();
let response = build_server_hello( let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX);
secret,
&client_digest,
&session_id,
1024,
&rng,
None,
u8::MAX,
);
let mut pos = 0usize; let mut pos = 0usize;
let mut app_records = 0usize; let mut app_records = 0usize;
@ -1992,10 +1935,7 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!( assert!(next <= response.len(), "TLS record must stay inside response bounds");
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
} }
@ -2003,7 +1943,8 @@ fn server_hello_new_session_ticket_count_is_safely_capped() {
} }
assert_eq!( assert_eq!(
app_records, 5, app_records,
5,
"response must cap ticket-like tail records to four plus one main application record" "response must cap ticket-like tail records to four plus one main application record"
); );
} }
@ -2031,14 +1972,10 @@ fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() {
std::thread::sleep(std::time::Duration::from_millis(70)); std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) .expect("boot-time handshake must still cryptographically validate after cache expiry");
.expect("boot-time handshake must still cryptographically validate after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
assert_eq!( assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake");
digest_half, digest_half_after_expiry,
"replay key must be stable for same handshake"
);
assert!( assert!(
checker.check_and_add_tls_digest(digest_half_after_expiry), checker.check_and_add_tls_digest(digest_half_after_expiry),
@ -2069,9 +2006,8 @@ fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry()
std::thread::sleep(std::time::Duration::from_millis(70)); std::thread::sleep(std::time::Duration::from_millis(70));
let validation_after_expiry = let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2) .expect("boot-time handshake still validates cryptographically after cache expiry");
.expect("boot-time handshake still validates cryptographically after cache expiry");
let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN]; let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
assert_eq!( assert_eq!(
@ -2131,14 +2067,11 @@ fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap
let ts = (s as u32) % 8; let ts = (s as u32) % 8;
let handshake = make_valid_tls_handshake(secret, ts); let handshake = make_valid_tls_handshake(secret, ts);
let accepted = let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2).is_some(); .is_some();
if ts < 2 { if ts < 2 {
assert!( assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap");
accepted,
"timestamp {ts} must remain boot-time compatible under 2s cap"
);
} else { } else {
assert!( assert!(
!accepted, !accepted,
@ -2174,9 +2107,7 @@ fn server_hello_application_data_contains_alpn_marker_when_selected() {
let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!( assert!(
app_payload app_payload.windows(expected.len()).any(|window| window == expected),
.windows(expected.len())
.any(|window| window == expected),
"first application payload must carry ALPN marker for selected protocol" "first application payload must carry ALPN marker for selected protocol"
); );
} }
@ -2206,10 +2137,7 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
let rtype = response[pos]; let rtype = response[pos];
let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
let next = pos + 5 + rlen; let next = pos + 5 + rlen;
assert!( assert!(next <= response.len(), "TLS record must stay inside response bounds");
next <= response.len(),
"TLS record must stay inside response bounds"
);
if rtype == TLS_RECORD_APPLICATION { if rtype == TLS_RECORD_APPLICATION {
app_records += 1; app_records += 1;
if first_app_payload.is_none() { if first_app_payload.is_none() {
@ -2218,9 +2146,7 @@ fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
} }
pos = next; pos = next;
} }
let marker = [ let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x'];
0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x',
];
assert_eq!( assert_eq!(
app_records, 5, app_records, 5,
@ -2384,13 +2310,13 @@ fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() {
&& header[1] == 0x03 && header[1] == 0x03
&& (header[2] == 0x01 || header[2] == 0x03); && (header[2] == 0x01 || header[2] == 0x03);
assert_eq!( assert_eq!(
classified, expected_classified, classified,
expected_classified,
"classifier policy mismatch for header {header:02x?}" "classifier policy mismatch for header {header:02x?}"
); );
let parsed = parse_tls_record_header(&header); let parsed = parse_tls_record_header(&header);
let expected_parsed = let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]);
header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]);
assert_eq!( assert_eq!(
parsed.is_some(), parsed.is_some(),
expected_parsed, expected_parsed,

View File

@ -1,4 +1,8 @@
use super::{MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use super::{
MAX_TLS_CIPHERTEXT_SIZE,
MAX_TLS_PLAINTEXT_SIZE,
MIN_TLS_CLIENT_HELLO_SIZE,
};
#[test] #[test]
fn tls_size_constants_match_rfc_8446() { fn tls_size_constants_match_rfc_8446() {

View File

@ -5,66 +5,11 @@
//! actually carries MTProto authentication data. //! actually carries MTProto authentication data.
#![allow(dead_code)] #![allow(dead_code)]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
use super::constants::*; use crate::crypto::{sha256_hmac, SecureRandom};
use crate::crypto::{SecureRandom, sha256_hmac};
#[cfg(test)] #[cfg(test)]
use crate::error::ProxyError; use crate::error::ProxyError;
use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
@ -86,7 +31,7 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Operators with known clock-drifted clients should tune deployment config /// Operators with known clock-drifted clients should tune deployment config
/// (for example replay-window policy) to match their environment. /// (for example replay-window policy) to match their environment.
pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before
pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. /// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance /// Hard cap for boot-time compatibility bypass to avoid oversized acceptance
@ -124,6 +69,7 @@ pub struct TlsValidation {
/// Client digest for response generation /// Client digest for response generation
pub digest: [u8; TLS_DIGEST_LEN], pub digest: [u8; TLS_DIGEST_LEN],
/// Timestamp extracted from digest /// Timestamp extracted from digest
pub timestamp: u32, pub timestamp: u32,
} }
@ -141,63 +87,60 @@ impl TlsExtensionBuilder {
extensions: Vec::with_capacity(128), extensions: Vec::with_capacity(128),
} }
} }
/// Add Key Share extension with X25519 key /// Add Key Share extension with X25519 key
fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self { fn add_key_share(&mut self, public_key: &[u8; 32]) -> &mut Self {
// Extension type: key_share (0x0033) // Extension type: key_share (0x0033)
self.extensions self.extensions.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
.extend_from_slice(&extension_type::KEY_SHARE.to_be_bytes());
// Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes // Key share entry: curve (2) + key_len (2) + key (32) = 36 bytes
// Extension data length // Extension data length
let entry_len: u16 = 2 + 2 + 32; // curve + length + key let entry_len: u16 = 2 + 2 + 32; // curve + length + key
self.extensions.extend_from_slice(&entry_len.to_be_bytes()); self.extensions.extend_from_slice(&entry_len.to_be_bytes());
// Named curve: x25519 // Named curve: x25519
self.extensions self.extensions.extend_from_slice(&named_curve::X25519.to_be_bytes());
.extend_from_slice(&named_curve::X25519.to_be_bytes());
// Key length // Key length
self.extensions.extend_from_slice(&(32u16).to_be_bytes()); self.extensions.extend_from_slice(&(32u16).to_be_bytes());
// Key data // Key data
self.extensions.extend_from_slice(public_key); self.extensions.extend_from_slice(public_key);
self self
} }
/// Add Supported Versions extension /// Add Supported Versions extension
fn add_supported_versions(&mut self, version: u16) -> &mut Self { fn add_supported_versions(&mut self, version: u16) -> &mut Self {
// Extension type: supported_versions (0x002b) // Extension type: supported_versions (0x002b)
self.extensions self.extensions.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
.extend_from_slice(&extension_type::SUPPORTED_VERSIONS.to_be_bytes());
// Extension data: length (2) + version (2) // Extension data: length (2) + version (2)
self.extensions.extend_from_slice(&(2u16).to_be_bytes()); self.extensions.extend_from_slice(&(2u16).to_be_bytes());
// Selected version // Selected version
self.extensions.extend_from_slice(&version.to_be_bytes()); self.extensions.extend_from_slice(&version.to_be_bytes());
self self
} }
/// Build final extensions with length prefix /// Build final extensions with length prefix
fn build(self) -> Vec<u8> { fn build(self) -> Vec<u8> {
let Ok(len) = u16::try_from(self.extensions.len()) else {
return Vec::new();
};
let mut result = Vec::with_capacity(2 + self.extensions.len()); let mut result = Vec::with_capacity(2 + self.extensions.len());
// Extensions length (2 bytes) // Extensions length (2 bytes)
let len = self.extensions.len() as u16;
result.extend_from_slice(&len.to_be_bytes()); result.extend_from_slice(&len.to_be_bytes());
// Extensions data // Extensions data
result.extend_from_slice(&self.extensions); result.extend_from_slice(&self.extensions);
result result
} }
/// Get current extensions without length prefix (for calculation) /// Get current extensions without length prefix (for calculation)
fn as_bytes(&self) -> &[u8] { fn as_bytes(&self) -> &[u8] {
&self.extensions &self.extensions
} }
@ -229,12 +172,12 @@ impl ServerHelloBuilder {
extensions: TlsExtensionBuilder::new(), extensions: TlsExtensionBuilder::new(),
} }
} }
fn with_x25519_key(mut self, key: &[u8; 32]) -> Self { fn with_x25519_key(mut self, key: &[u8; 32]) -> Self {
self.extensions.add_key_share(key); self.extensions.add_key_share(key);
self self
} }
fn with_tls13_version(mut self) -> Self { fn with_tls13_version(mut self) -> Self {
// TLS 1.3 = 0x0304 // TLS 1.3 = 0x0304
self.extensions.add_supported_versions(0x0304); self.extensions.add_supported_versions(0x0304);
@ -243,14 +186,9 @@ impl ServerHelloBuilder {
/// Build ServerHello message (without record header) /// Build ServerHello message (without record header)
fn build_message(&self) -> Vec<u8> { fn build_message(&self) -> Vec<u8> {
let Ok(session_id_len) = u8::try_from(self.session_id.len()) else {
return Vec::new();
};
let extensions = self.extensions.extensions.clone(); let extensions = self.extensions.extensions.clone();
let Ok(extensions_len) = u16::try_from(extensions.len()) else { let extensions_len = extensions.len() as u16;
return Vec::new();
};
// Calculate total length // Calculate total length
let body_len = 2 + // version let body_len = 2 + // version
32 + // random 32 + // random
@ -258,67 +196,55 @@ impl ServerHelloBuilder {
2 + // cipher suite 2 + // cipher suite
1 + // compression 1 + // compression
2 + extensions.len(); // extensions length + data 2 + extensions.len(); // extensions length + data
if body_len > 0x00ff_ffff {
return Vec::new();
}
let mut message = Vec::with_capacity(4 + body_len); let mut message = Vec::with_capacity(4 + body_len);
// Handshake header // Handshake header
message.push(0x02); // ServerHello message type message.push(0x02); // ServerHello message type
// 3-byte length // 3-byte length
let Ok(body_len_u32) = u32::try_from(body_len) else { let len_bytes = (body_len as u32).to_be_bytes();
return Vec::new();
};
let len_bytes = body_len_u32.to_be_bytes();
message.extend_from_slice(&len_bytes[1..4]); message.extend_from_slice(&len_bytes[1..4]);
// Server version (TLS 1.2 in header, actual version in extension) // Server version (TLS 1.2 in header, actual version in extension)
message.extend_from_slice(&TLS_VERSION); message.extend_from_slice(&TLS_VERSION);
// Random (32 bytes) - placeholder, will be replaced with digest // Random (32 bytes) - placeholder, will be replaced with digest
message.extend_from_slice(&self.random); message.extend_from_slice(&self.random);
// Session ID // Session ID
message.push(session_id_len); message.push(self.session_id.len() as u8);
message.extend_from_slice(&self.session_id); message.extend_from_slice(&self.session_id);
// Cipher suite // Cipher suite
message.extend_from_slice(&self.cipher_suite); message.extend_from_slice(&self.cipher_suite);
// Compression method // Compression method
message.push(self.compression); message.push(self.compression);
// Extensions length // Extensions length
message.extend_from_slice(&extensions_len.to_be_bytes()); message.extend_from_slice(&extensions_len.to_be_bytes());
// Extensions data // Extensions data
message.extend_from_slice(&extensions); message.extend_from_slice(&extensions);
message message
} }
/// Build complete ServerHello TLS record /// Build complete ServerHello TLS record
fn build_record(&self) -> Vec<u8> { fn build_record(&self) -> Vec<u8> {
let message = self.build_message(); let message = self.build_message();
if message.is_empty() {
return Vec::new();
}
let Ok(message_len) = u16::try_from(message.len()) else {
return Vec::new();
};
let mut record = Vec::with_capacity(5 + message.len()); let mut record = Vec::with_capacity(5 + message.len());
// TLS record header // TLS record header
record.push(TLS_RECORD_HANDSHAKE); record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&message_len.to_be_bytes()); record.extend_from_slice(&(message.len() as u16).to_be_bytes());
// Message // Message
record.extend_from_slice(&message); record.extend_from_slice(&message);
record record
} }
} }
@ -330,6 +256,7 @@ impl ServerHelloBuilder {
/// Returns validation result if a matching user is found. /// Returns validation result if a matching user is found.
/// The result **must** be used — ignoring it silently bypasses authentication. /// The result **must** be used — ignoring it silently bypasses authentication.
#[must_use] #[must_use]
pub fn validate_tls_handshake( pub fn validate_tls_handshake(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
@ -393,6 +320,7 @@ fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
i64::try_from(d.as_secs()).ok() i64::try_from(d.as_secs()).ok()
} }
fn validate_tls_handshake_at_time( fn validate_tls_handshake_at_time(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
@ -418,12 +346,12 @@ fn validate_tls_handshake_at_time_with_boot_cap(
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
return None; return None;
} }
// Extract digest // Extract digest
let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] let digest: [u8; TLS_DIGEST_LEN] = handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.try_into() .try_into()
.ok()?; .ok()?;
// Extract session ID // Extract session ID
let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN;
let session_id_len = handshake.get(session_id_len_pos).copied()? as usize; let session_id_len = handshake.get(session_id_len_pos).copied()? as usize;
@ -431,17 +359,17 @@ fn validate_tls_handshake_at_time_with_boot_cap(
return None; return None;
} }
let session_id_start = session_id_len_pos + 1; let session_id_start = session_id_len_pos + 1;
if handshake.len() < session_id_start + session_id_len { if handshake.len() < session_id_start + session_id_len {
return None; return None;
} }
let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec(); let session_id = handshake[session_id_start..session_id_start + session_id_len].to_vec();
// Build message for HMAC (with zeroed digest) // Build message for HMAC (with zeroed digest)
let mut msg = handshake.to_vec(); let mut msg = handshake.to_vec();
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
let mut first_match: Option<(&String, u32)> = None; let mut first_match: Option<(&String, u32)> = None;
for (user, secret) in secrets { for (user, secret) in secrets {
@ -480,7 +408,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
} }
} }
} }
if first_match.is_none() { if first_match.is_none() {
first_match = Some((user, timestamp)); first_match = Some((user, timestamp));
} }
@ -525,30 +453,25 @@ pub fn build_server_hello(
const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE;
let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA); let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA);
let x25519_key = gen_fake_x25519_key(rng); let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello // Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec()) let server_hello = ServerHelloBuilder::new(session_id.to_vec())
.with_x25519_key(&x25519_key) .with_x25519_key(&x25519_key)
.with_tls13_version() .with_tls13_version()
.build_record(); .build_record();
// Build Change Cipher Spec record // Build Change Cipher Spec record
let change_cipher_spec = [ let change_cipher_spec = [
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0], TLS_VERSION[0], TLS_VERSION[1],
TLS_VERSION[1], 0x00, 0x01, // length = 1
0x00, 0x01, // CCS byte
0x01, // length = 1
0x01, // CCS byte
]; ];
// Build first encrypted flight mimic as opaque ApplicationData bytes. // Build first encrypted flight mimic as opaque ApplicationData bytes.
// Embed a compact EncryptedExtensions-like ALPN block when selected. // Embed a compact EncryptedExtensions-like ALPN block when selected.
let mut fake_cert = Vec::with_capacity(fake_cert_len); let mut fake_cert = Vec::with_capacity(fake_cert_len);
if let Some(proto) = alpn if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) {
.as_ref()
.filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize)
{
let proto_list_len = 1usize + proto.len(); let proto_list_len = 1usize + proto.len();
let ext_data_len = 2usize + proto_list_len; let ext_data_len = 2usize + proto_list_len;
let marker_len = 4usize + ext_data_len; let marker_len = 4usize + ext_data_len;
@ -573,7 +496,7 @@ pub fn build_server_hello(
// Fill ApplicationData with fully random bytes of desired length to avoid // Fill ApplicationData with fully random bytes of desired length to avoid
// deterministic DPI fingerprints (fixed inner content type markers). // deterministic DPI fingerprints (fixed inner content type markers).
app_data_record.extend_from_slice(&fake_cert); app_data_record.extend_from_slice(&fake_cert);
// Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted; // Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted;
// here we mimic with opaque ApplicationData records of plausible size). // here we mimic with opaque ApplicationData records of plausible size).
let mut tickets = Vec::new(); let mut tickets = Vec::new();
@ -592,10 +515,7 @@ pub fn build_server_hello(
// Combine all records // Combine all records
let mut response = Vec::with_capacity( let mut response = Vec::with_capacity(
server_hello.len() server_hello.len() + change_cipher_spec.len() + app_data_record.len() + tickets.iter().map(|r| r.len()).sum::<usize>()
+ change_cipher_spec.len()
+ app_data_record.len()
+ tickets.iter().map(|r| r.len()).sum::<usize>(),
); );
response.extend_from_slice(&server_hello); response.extend_from_slice(&server_hello);
response.extend_from_slice(&change_cipher_spec); response.extend_from_slice(&change_cipher_spec);
@ -603,17 +523,18 @@ pub fn build_server_hello(
for t in &tickets { for t in &tickets {
response.extend_from_slice(t); response.extend_from_slice(t);
} }
// Compute HMAC for the response // Compute HMAC for the response
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len()); let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
hmac_input.extend_from_slice(client_digest); hmac_input.extend_from_slice(client_digest);
hmac_input.extend_from_slice(&response); hmac_input.extend_from_slice(&response);
let response_digest = sha256_hmac(secret, &hmac_input); let response_digest = sha256_hmac(secret, &hmac_input);
// Insert computed digest into ServerHello // Insert computed digest into ServerHello
// Position: record header (5) + message type (1) + length (3) + version (2) = 11 // Position: record header (5) + message type (1) + length (3) + version (2) = 11
response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&response_digest); response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
.copy_from_slice(&response_digest);
response response
} }
@ -690,19 +611,18 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
let sn_end = std::cmp::min(sn_pos + list_len, pos + elen); let sn_end = std::cmp::min(sn_pos + list_len, pos + elen);
while sn_pos + 3 <= sn_end { while sn_pos + 3 <= sn_end {
let name_type = handshake[sn_pos]; let name_type = handshake[sn_pos];
let name_len = let name_len = u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
u16::from_be_bytes([handshake[sn_pos + 1], handshake[sn_pos + 2]]) as usize;
sn_pos += 3; sn_pos += 3;
if sn_pos + name_len > sn_end { if sn_pos + name_len > sn_end {
break; break;
} }
if name_type == 0 if name_type == 0 && name_len > 0
&& name_len > 0
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
&& is_valid_sni_hostname(host)
{ {
extracted_sni = Some(host.to_string()); if is_valid_sni_hostname(host) {
break; extracted_sni = Some(host.to_string());
break;
}
} }
sn_pos += name_len; sn_pos += name_len;
} }
@ -759,49 +679,35 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
} }
pos += 4; // type + len pos += 4; // type + len
pos += 2 + 32; // version + random pos += 2 + 32; // version + random
if pos >= handshake.len() { if pos >= handshake.len() { return Vec::new(); }
return Vec::new();
}
let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize; let session_id_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + session_id_len; pos += 1 + session_id_len;
if pos + 2 > handshake.len() { if pos + 2 > handshake.len() { return Vec::new(); }
return Vec::new(); let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
}
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2 + cipher_len; pos += 2 + cipher_len;
if pos >= handshake.len() { if pos >= handshake.len() { return Vec::new(); }
return Vec::new();
}
let comp_len = *handshake.get(pos).unwrap_or(&0) as usize; let comp_len = *handshake.get(pos).unwrap_or(&0) as usize;
pos += 1 + comp_len; pos += 1 + comp_len;
if pos + 2 > handshake.len() { if pos + 2 > handshake.len() { return Vec::new(); }
return Vec::new(); let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
}
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2; pos += 2;
let ext_end = pos + ext_len; let ext_end = pos + ext_len;
if ext_end > handshake.len() { if ext_end > handshake.len() { return Vec::new(); }
return Vec::new();
}
let mut out = Vec::new(); let mut out = Vec::new();
while pos + 4 <= ext_end { while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); let etype = u16::from_be_bytes([handshake[pos], handshake[pos+1]]);
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; let elen = u16::from_be_bytes([handshake[pos+2], handshake[pos+3]]) as usize;
pos += 4; pos += 4;
if pos + elen > ext_end { if pos + elen > ext_end { break; }
break;
}
if etype == extension_type::ALPN && elen >= 3 { if etype == extension_type::ALPN && elen >= 3 {
let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; let list_len = u16::from_be_bytes([handshake[pos], handshake[pos+1]]) as usize;
let mut lp = pos + 2; let mut lp = pos + 2;
let list_end = (pos + 2).saturating_add(list_len).min(pos + elen); let list_end = (pos + 2).saturating_add(list_len).min(pos + elen);
while lp < list_end { while lp < list_end {
let plen = handshake[lp] as usize; let plen = handshake[lp] as usize;
lp += 1; lp += 1;
if lp + plen > list_end { if lp + plen > list_end { break; }
break; out.push(handshake[lp..lp+plen].to_vec());
}
out.push(handshake[lp..lp + plen].to_vec());
lp += plen; lp += plen;
} }
break; break;
@ -811,28 +717,30 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
out out
} }
/// Check if bytes look like a TLS ClientHello /// Check if bytes look like a TLS ClientHello
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
if first_bytes.len() < 3 { if first_bytes.len() < 3 {
return false; return false;
} }
// TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303.
first_bytes[0] == TLS_RECORD_HANDSHAKE first_bytes[0] == TLS_RECORD_HANDSHAKE
&& first_bytes[1] == 0x03 && first_bytes[1] == 0x03
&& (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03)
} }
/// Parse TLS record header, returns (record_type, length) /// Parse TLS record header, returns (record_type, length)
pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> {
let record_type = header[0]; let record_type = header[0];
let version = [header[1], header[2]]; let version = [header[1], header[2]];
// We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3 // We accept both TLS 1.0 header (for ClientHello) and TLS 1.2/1.3
if version != [0x03, 0x01] && version != TLS_VERSION { if version != [0x03, 0x01] && version != TLS_VERSION {
return None; return None;
} }
let length = u16::from_be_bytes([header[3], header[4]]); let length = u16::from_be_bytes([header[3], header[4]]);
Some((record_type, length)) Some((record_type, length))
} }
@ -848,7 +756,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [0, 0], version: [0, 0],
}); });
} }
// Check record header // Check record header
if data[0] != TLS_RECORD_HANDSHAKE { if data[0] != TLS_RECORD_HANDSHAKE {
return Err(ProxyError::InvalidTlsRecord { return Err(ProxyError::InvalidTlsRecord {
@ -856,7 +764,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [data[1], data[2]], version: [data[1], data[2]],
}); });
} }
// Check version // Check version
if data[1..3] != TLS_VERSION { if data[1..3] != TLS_VERSION {
return Err(ProxyError::InvalidTlsRecord { return Err(ProxyError::InvalidTlsRecord {
@ -864,34 +772,31 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
version: [data[1], data[2]], version: [data[1], data[2]],
}); });
} }
// Check record length // Check record length
let record_len = u16::from_be_bytes([data[3], data[4]]) as usize; let record_len = u16::from_be_bytes([data[3], data[4]]) as usize;
if data.len() < 5 + record_len { if data.len() < 5 + record_len {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(
"ServerHello record truncated: expected {}, got {}", format!("ServerHello record truncated: expected {}, got {}",
5 + record_len, 5 + record_len, data.len())
data.len() ));
)));
} }
// Check message type // Check message type
if data[5] != 0x02 { if data[5] != 0x02 {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(
"Expected ServerHello (0x02), got 0x{:02x}", format!("Expected ServerHello (0x02), got 0x{:02x}", data[5])
data[5] ));
)));
} }
// Parse message length // Parse message length
let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize; let msg_len = u32::from_be_bytes([0, data[6], data[7], data[8]]) as usize;
if msg_len + 4 != record_len { if msg_len + 4 != record_len {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(
"Message length mismatch: {} + 4 != {}", format!("Message length mismatch: {} + 4 != {}", msg_len, record_len)
msg_len, record_len ));
)));
} }
Ok(()) Ok(())
} }
@ -901,7 +806,7 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
/// Using `static_assertions` ensures these can never silently break across /// Using `static_assertions` ensures these can never silently break across
/// refactors without a compile error. /// refactors without a compile error.
mod compile_time_security_checks { mod compile_time_security_checks {
use super::{TLS_DIGEST_HALF_LEN, TLS_DIGEST_LEN}; use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN};
use static_assertions::const_assert; use static_assertions::const_assert;
// The digest must be exactly one SHA-256 output. // The digest must be exactly one SHA-256 output.
@ -929,7 +834,3 @@ mod adversarial_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/tls_fuzz_security_tests.rs"] #[path = "tests/tls_fuzz_security_tests.rs"]
mod fuzz_security_tests; mod fuzz_security_tests;
#[cfg(test)]
#[path = "tests/tls_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;

View File

@ -1,8 +1,3 @@
#![allow(dead_code)]
// Adaptive buffer policy is staged and retained for deterministic rollout.
// Keep definitions compiled for compatibility and security test scaffolding.
use dashmap::DashMap; use dashmap::DashMap;
use std::cmp::max; use std::cmp::max;
use std::sync::OnceLock; use std::sync::OnceLock;
@ -175,8 +170,7 @@ impl SessionAdaptiveController {
return self.promote(TierTransitionReason::SoftConfirmed, 0); return self.promote(TierTransitionReason::SoftConfirmed, 0);
} }
let demote_candidate = let demote_candidate = self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now;
self.throughput_ema_bps < THROUGHPUT_DOWN_BPS && !tier2_now && !hard_now;
if demote_candidate { if demote_candidate {
self.quiet_ticks = self.quiet_ticks.saturating_add(1); self.quiet_ticks = self.quiet_ticks.saturating_add(1);
if self.quiet_ticks >= QUIET_DEMOTE_TICKS { if self.quiet_ticks >= QUIET_DEMOTE_TICKS {
@ -259,7 +253,10 @@ pub fn record_user_tier(user: &str, tier: AdaptiveTier) {
}; };
return; return;
} }
profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now }); profiles().insert(
user.to_string(),
UserAdaptiveProfile { tier, seen_at: now },
);
} }
pub fn direct_copy_buffers_for_tier( pub fn direct_copy_buffers_for_tier(
@ -342,7 +339,10 @@ mod tests {
sample( sample(
300_000, // ~9.6 Mbps 300_000, // ~9.6 Mbps
320_000, // incoming > outgoing to confirm tier2 320_000, // incoming > outgoing to confirm tier2
250_000, 10, 0, 0, 250_000,
10,
0,
0,
), ),
tick_secs, tick_secs,
); );
@ -358,7 +358,10 @@ mod tests {
fn test_hard_promotion_on_pending_pressure() { fn test_hard_promotion_on_pending_pressure() {
let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base); let mut ctrl = SessionAdaptiveController::new(AdaptiveTier::Base);
let transition = ctrl let transition = ctrl
.observe(sample(10_000, 20_000, 10_000, 4, 1, 3), 0.25) .observe(
sample(10_000, 20_000, 10_000, 4, 1, 3),
0.25,
)
.expect("expected hard promotion"); .expect("expected hard promotion");
assert_eq!(transition.reason, TierTransitionReason::HardPressure); assert_eq!(transition.reason, TierTransitionReason::HardPressure);
assert_eq!(transition.to, AdaptiveTier::Tier1); assert_eq!(transition.to, AdaptiveTier::Tier1);

View File

@ -1,7 +1,5 @@
//! Client Handler //! Client Handler
use ipnetwork::IpNetwork;
use rand::RngExt;
use std::future::Future; use std::future::Future;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::Pin; use std::pin::Pin;
@ -9,6 +7,8 @@ use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration; use std::time::Duration;
use ipnetwork::IpNetwork;
use rand::RngExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
@ -75,10 +75,10 @@ use crate::protocol::tls;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::tls_front::TlsFrontCache;
use crate::transport::middle_proxy::MePool; use crate::transport::middle_proxy::MePool;
use crate::transport::socket::normalize_ip;
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol}; use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
use crate::transport::socket::normalize_ip;
use crate::tls_front::TlsFrontCache;
use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
@ -116,23 +116,11 @@ fn beobachten_ttl(config: &ProxyConfig) -> Duration {
} }
fn wrap_tls_application_record(payload: &[u8]) -> Vec<u8> { fn wrap_tls_application_record(payload: &[u8]) -> Vec<u8> {
let chunks = payload.len().div_ceil(u16::MAX as usize).max(1); let mut record = Vec::with_capacity(5 + payload.len());
let mut record = Vec::with_capacity(payload.len() + 5 * chunks); record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
if payload.is_empty() { record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.push(TLS_RECORD_APPLICATION); record.extend_from_slice(payload);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&0u16.to_be_bytes());
return record;
}
for chunk in payload.chunks(u16::MAX as usize) {
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(chunk.len() as u16).to_be_bytes());
record.extend_from_slice(chunk);
}
record record
} }
@ -140,10 +128,7 @@ fn tls_clienthello_len_in_bounds(tls_len: usize) -> bool {
(MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len)
} }
async fn read_with_progress<R: AsyncRead + Unpin>( async fn read_with_progress<R: AsyncRead + Unpin>(reader: &mut R, mut buf: &mut [u8]) -> std::io::Result<usize> {
reader: &mut R,
mut buf: &mut [u8],
) -> std::io::Result<usize> {
let mut total = 0usize; let mut total = 0usize;
while !buf.is_empty() { while !buf.is_empty() {
match reader.read(buf).await { match reader.read(buf).await {
@ -286,14 +271,10 @@ where
let mut local_addr = synthetic_local_addr(config.server.port); let mut local_addr = synthetic_local_addr(config.server.port);
if proxy_protocol_enabled { if proxy_protocol_enabled {
let proxy_header_timeout = let proxy_header_timeout = Duration::from_millis(
Duration::from_millis(config.server.proxy_protocol_header_timeout_ms.max(1)); config.server.proxy_protocol_header_timeout_ms.max(1),
match timeout( );
proxy_header_timeout, match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
parse_proxy_protocol(&mut stream, peer),
)
.await
{
Ok(Ok(info)) => { Ok(Ok(info)) => {
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
{ {
@ -693,8 +674,9 @@ impl RunningClientHandler {
let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
if self.proxy_protocol_enabled { if self.proxy_protocol_enabled {
let proxy_header_timeout = let proxy_header_timeout = Duration::from_millis(
Duration::from_millis(self.config.server.proxy_protocol_header_timeout_ms.max(1)); self.config.server.proxy_protocol_header_timeout_ms.max(1),
);
match timeout( match timeout(
proxy_header_timeout, proxy_header_timeout,
parse_proxy_protocol(&mut self.stream, self.peer), parse_proxy_protocol(&mut self.stream, self.peer),
@ -779,11 +761,7 @@ impl RunningClientHandler {
} }
} }
async fn handle_tls_client( async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
mut self,
first_bytes: [u8; 5],
local_addr: SocketAddr,
) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@ -917,8 +895,7 @@ impl RunningClientHandler {
} else { } else {
wrap_tls_application_record(&pending_plaintext) wrap_tls_application_record(&pending_plaintext)
}; };
let reader = let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader);
tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader);
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!( debug!(
peer = %peer, peer = %peer,
@ -956,11 +933,7 @@ impl RunningClientHandler {
))) )))
} }
async fn handle_direct_client( async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
mut self,
first_bytes: [u8; 5],
local_addr: SocketAddr,
) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
@ -1062,21 +1035,22 @@ impl RunningClientHandler {
{ {
let user = success.user.clone(); let user = success.user.clone();
let user_limit_reservation = match Self::acquire_user_connection_reservation_static( let user_limit_reservation =
&user, match Self::acquire_user_connection_reservation_static(
&config, &user,
stats.clone(), &config,
peer_addr, stats.clone(),
ip_tracker, peer_addr,
) ip_tracker,
.await )
{ .await
Ok(reservation) => reservation, {
Err(e) => { Ok(reservation) => reservation,
warn!(user = %user, error = %e, "User admission check failed"); Err(e) => {
return Err(e); warn!(user = %user, error = %e, "User admission check failed");
} return Err(e);
}; }
};
let route_snapshot = route_runtime.snapshot(); let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64(); let session_id = rng.u64();
@ -1160,11 +1134,7 @@ impl RunningClientHandler {
}); });
} }
let limit = config let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64);
.access
.user_max_tcp_conns
.get(user)
.map(|v| *v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) { if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded { return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(), user: user.to_string(),
@ -1324,7 +1294,3 @@ mod masking_probe_evasion_blackhat_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] #[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests; mod beobachten_ttl_bounds_security_tests;
#[cfg(test)]
#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"]
mod tls_record_wrap_hardening_security_tests;

View File

@ -1,10 +1,10 @@
use std::collections::HashSet;
use std::ffi::OsString; use std::ffi::OsString;
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::Write; use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::{Component, Path, PathBuf}; use std::path::{Component, Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
@ -24,13 +24,13 @@ use crate::proxy::route_mode::{
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
#[cfg(unix)]
use nix::fcntl::{Flock, FlockArg, OFlag, openat};
#[cfg(unix)]
use nix::sys::stat::Mode;
#[cfg(unix)] #[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt; use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd};
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new(); static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
@ -160,9 +160,7 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
} }
} }
fn open_unknown_dc_log_append_anchored( fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result<std::fs::File> {
path: &SanitizedUnknownDcLogPath,
) -> std::io::Result<std::fs::File> {
#[cfg(unix)] #[cfg(unix)]
{ {
let parent = OpenOptions::new() let parent = OpenOptions::new()
@ -170,16 +168,23 @@ fn open_unknown_dc_log_append_anchored(
.custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC)
.open(&path.allowed_parent)?; .open(&path.allowed_parent)?;
let oflags = OFlag::O_CREAT let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes())
| OFlag::O_APPEND .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?;
| OFlag::O_WRONLY
| OFlag::O_NOFOLLOW let fd = unsafe {
| OFlag::O_CLOEXEC; libc::openat(
let mode = Mode::from_bits_truncate(0o600); parent.as_raw_fd(),
let path_component = Path::new(path.file_name.as_os_str()); file_name.as_ptr(),
let fd = openat(&parent, path_component, oflags, mode) libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC,
.map_err(|err| std::io::Error::from_raw_os_error(err as i32))?; 0o600,
let file = std::fs::File::from(fd); )
};
if fd < 0 {
return Err(std::io::Error::last_os_error());
}
let file = unsafe { std::fs::File::from_raw_fd(fd) };
Ok(file) Ok(file)
} }
#[cfg(not(unix))] #[cfg(not(unix))]
@ -195,13 +200,16 @@ fn open_unknown_dc_log_append_anchored(
fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> {
#[cfg(unix)] #[cfg(unix)]
{ {
let cloned = file.try_clone()?; if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 {
let mut locked = Flock::lock(cloned, FlockArg::LockExclusive) return Err(std::io::Error::last_os_error());
.map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; }
let write_result = writeln!(&mut *locked, "dc_idx={dc_idx}");
let _ = locked let write_result = writeln!(file, "dc_idx={dc_idx}");
.unlock()
.map_err(|(_, err)| std::io::Error::from_raw_os_error(err as i32))?; if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 {
return Err(std::io::Error::last_os_error());
}
write_result write_result
} }
#[cfg(not(unix))] #[cfg(not(unix))]

View File

@ -2,29 +2,29 @@
#![allow(dead_code)] #![allow(dead_code)]
use dashmap::DashMap; use std::net::SocketAddr;
use dashmap::mapref::entry::Entry;
use std::collections::HashSet; use std::collections::HashSet;
use std::collections::hash_map::RandomState; use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv6Addr}; use std::net::{IpAddr, Ipv6Addr};
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::hash::{BuildHasher, Hash, Hasher};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, trace, warn}; use tracing::{debug, warn, trace};
use zeroize::{Zeroize, Zeroizing}; use zeroize::{Zeroize, Zeroizing};
use crate::config::ProxyConfig; use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::crypto::{AesCtr, SecureRandom, sha256}; use rand::RngExt;
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
use crate::error::{ProxyError, HandshakeResult};
use crate::stats::ReplayChecker; use crate::stats::ReplayChecker;
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::config::ProxyConfig;
use crate::tls_front::{TlsFrontCache, emulator}; use crate::tls_front::{TlsFrontCache, emulator};
use rand::RngExt;
const ACCESS_SECRET_BYTES: usize = 16; const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new(); static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
@ -67,8 +67,7 @@ struct AuthProbeSaturationState {
} }
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new(); static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = static AUTH_PROBE_SATURATION_STATE: OnceLock<Mutex<Option<AuthProbeSaturationState>>> = OnceLock::new();
OnceLock::new();
static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new(); static AUTH_PROBE_EVICTION_HASHER: OnceLock<RandomState> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> { fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
@ -79,8 +78,8 @@ fn auth_probe_saturation_state() -> &'static Mutex<Option<AuthProbeSaturationSta
AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None))
} }
fn auth_probe_saturation_state_lock() fn auth_probe_saturation_state_lock(
-> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> { ) -> std::sync::MutexGuard<'static, Option<AuthProbeSaturationState>> {
auth_probe_saturation_state() auth_probe_saturation_state()
.lock() .lock()
.unwrap_or_else(|poisoned| poisoned.into_inner()) .unwrap_or_else(|poisoned| poisoned.into_inner())
@ -253,7 +252,9 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {} || (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
} }
@ -283,7 +284,9 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {} || (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
if auth_probe_state_expired(entry.value(), now) { if auth_probe_state_expired(entry.value(), now) {
@ -303,7 +306,9 @@ fn auth_probe_record_failure_with_state(
match eviction_candidate { match eviction_candidate {
Some((_, current_fail, current_seen)) Some((_, current_fail, current_seen))
if fail_streak > current_fail if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {} || (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)), _ => eviction_candidate = Some((key, fail_streak, last_seen)),
} }
if auth_probe_state_expired(entry.value(), now) { if auth_probe_state_expired(entry.value(), now) {
@ -534,12 +539,13 @@ pub struct HandshakeSuccess {
/// Decryption key and IV (for reading from client) /// Decryption key and IV (for reading from client)
pub dec_key: [u8; 32], pub dec_key: [u8; 32],
pub dec_iv: u128, pub dec_iv: u128,
/// Encryption key and IV (for writing to client) /// Encryption key and IV (for writing to client)
pub enc_key: [u8; 32], pub enc_key: [u8; 32],
pub enc_iv: u128, pub enc_iv: u128,
/// Client address /// Client address
pub peer: SocketAddr, pub peer: SocketAddr,
/// Whether TLS was used /// Whether TLS was used
pub is_tls: bool, pub is_tls: bool,
} }
@ -597,7 +603,7 @@ where
auth_probe_record_failure(peer.ip(), Instant::now()); auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await; maybe_apply_server_hello_delay(config).await;
debug!( debug!(
peer = %peer, peer = %peer,
ignore_time_skew = config.access.ignore_time_skew, ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew" "TLS handshake validation failed - no matching user or time skew"
); );
@ -626,7 +632,7 @@ where
let cached = if config.censorship.tls_emulation { let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() { if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = client_sni.as_ref() { let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(sni).await { if cache.contains_domain(&sni).await {
sni.clone() sni.clone()
} else { } else {
config.censorship.tls_domain.clone() config.censorship.tls_domain.clone()
@ -763,6 +769,7 @@ where
let decoded_users = decode_user_secrets(config, preferred_user); let decoded_users = decode_user_secrets(config, preferred_user);
for (user, secret) in decoded_users { for (user, secret) in decoded_users {
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@ -813,12 +820,12 @@ where
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
// Apply replay tracking only after successful authentication. // Apply replay tracking only after successful authentication.
// //
// This ordering prevents an attacker from producing invalid handshakes that // This ordering prevents an attacker from producing invalid handshakes that
// still collide with a valid handshake's replay slot and thus evict a valid // still collide with a valid handshake's replay slot and thus evict a valid
// entry from the cache. We accept the cost of performing the full // entry from the cache. We accept the cost of performing the full
// authentication check first to avoid poisoning the replay cache. // authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) { if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now()); auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await; maybe_apply_server_hello_delay(config).await;
@ -865,7 +872,7 @@ where
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
dc_idx: i16, dc_idx: i16,
client_enc_key: &[u8; 32], client_enc_key: &[u8; 32],
client_enc_iv: u128, client_enc_iv: u128,
@ -878,19 +885,13 @@ pub fn generate_tg_nonce(
continue; continue;
}; };
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
continue;
}
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
continue;
}
let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]];
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
continue;
}
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route // CRITICAL: write dc_idx so upstream DC knows where to route
@ -941,7 +942,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
let dec_iv = u128::from_be_bytes(dec_iv_arr); let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut encryptor = AesCtr::new(&enc_key, enc_iv); let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
let mut result = nonce[..PROTO_TAG_POS].to_vec(); let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
@ -954,6 +955,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
} }
/// Encrypt nonce for sending to Telegram (legacy function for compatibility) /// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted encrypted

View File

@ -1,19 +1,19 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::str;
use std::net::SocketAddr;
use std::time::Duration;
use rand::{Rng, RngExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Instant, timeout};
use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
use std::str;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::time::{Instant, timeout};
use tracing::debug;
#[cfg(not(test))] #[cfg(not(test))]
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
@ -98,8 +98,8 @@ async fn maybe_write_shape_padding<W>(
cap: usize, cap: usize,
above_cap_blur: bool, above_cap_blur: bool,
above_cap_blur_max_bytes: usize, above_cap_blur_max_bytes: usize,
aggressive_mode: bool, )
) where where
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
{ {
if !enabled { if !enabled {
@ -108,11 +108,7 @@ async fn maybe_write_shape_padding<W>(
let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 { let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 {
let mut rng = rand::rng(); let mut rng = rand::rng();
let extra = if aggressive_mode { let extra = rng.random_range(0..=above_cap_blur_max_bytes);
rng.random_range(1..=above_cap_blur_max_bytes)
} else {
rng.random_range(0..=above_cap_blur_max_bytes)
};
total_sent.saturating_add(extra) total_sent.saturating_add(extra)
} else { } else {
next_mask_shape_bucket(total_sent, floor, cap) next_mask_shape_bucket(total_sent, floor, cap)
@ -171,10 +167,7 @@ async fn consume_client_data_with_timeout<R>(reader: R)
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
{ {
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)) if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() {
.await
.is_err()
{
debug!("Timed out while consuming client data on masking fallback path"); debug!("Timed out while consuming client data on masking fallback path");
} }
} }
@ -220,12 +213,9 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
fn detect_client_type(data: &[u8]) -> &'static str { fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request // Check for HTTP request
if data.len() > 4 if data.len() > 4
&& (data.starts_with(b"GET ") && (data.starts_with(b"GET ") || data.starts_with(b"POST") ||
|| data.starts_with(b"POST") data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
|| data.starts_with(b"HEAD") data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS"))
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{ {
return "HTTP"; return "HTTP";
} }
@ -262,12 +252,16 @@ fn build_mask_proxy_header(
), ),
_ => { _ => {
let header = match (peer, local_addr) { let header = match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new() (SocketAddr::V4(src), SocketAddr::V4(dst)) => {
.tcp4(src.into(), dst.into()) ProxyProtocolV1Builder::new()
.build(), .tcp4(src.into(), dst.into())
(SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new() .build()
.tcp6(src.into(), dst.into()) }
.build(), (SocketAddr::V6(src), SocketAddr::V6(dst)) => {
ProxyProtocolV1Builder::new()
.tcp6(src.into(), dst.into())
.build()
}
_ => ProxyProtocolV1Builder::new().build(), _ => ProxyProtocolV1Builder::new().build(),
}; };
Some(header) Some(header)
@ -284,7 +278,8 @@ pub async fn handle_bad_client<R, W>(
local_addr: SocketAddr, local_addr: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
beobachten: &BeobachtenStore, beobachten: &BeobachtenStore,
) where )
where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
{ {
@ -316,16 +311,13 @@ pub async fn handle_bad_client<R, W>(
match connect_result { match connect_result {
Ok(Ok(stream)) => { Ok(Ok(stream)) => {
let (mask_read, mut mask_write) = stream.into_split(); let (mask_read, mut mask_write) = stream.into_split();
let proxy_header = build_mask_proxy_header( let proxy_header =
config.censorship.mask_proxy_protocol, build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
peer, if let Some(header) = proxy_header {
local_addr, if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
); wait_mask_outcome_budget(outcome_started, config).await;
if let Some(header) = proxy_header return;
&& !write_proxy_header_with_timeout(&mut mask_write, &header).await }
{
wait_mask_outcome_budget(outcome_started, config).await;
return;
} }
if timeout( if timeout(
MASK_RELAY_TIMEOUT, MASK_RELAY_TIMEOUT,
@ -340,7 +332,6 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_bucket_cap_bytes, config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
), ),
) )
.await .await
@ -365,10 +356,7 @@ pub async fn handle_bad_client<R, W>(
return; return;
} }
let mask_host = config let mask_host = config.censorship.mask_host.as_deref()
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain); .unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port; let mask_port = config.censorship.mask_port;
@ -393,11 +381,11 @@ pub async fn handle_bad_client<R, W>(
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
let (mask_read, mut mask_write) = stream.into_split(); let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header if let Some(header) = proxy_header {
&& !write_proxy_header_with_timeout(&mut mask_write, &header).await if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
{ wait_mask_outcome_budget(outcome_started, config).await;
wait_mask_outcome_budget(outcome_started, config).await; return;
return; }
} }
if timeout( if timeout(
MASK_RELAY_TIMEOUT, MASK_RELAY_TIMEOUT,
@ -412,7 +400,6 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_bucket_cap_bytes, config.censorship.mask_shape_bucket_cap_bytes,
config.censorship.mask_shape_above_cap_blur, config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes, config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
), ),
) )
.await .await
@ -448,8 +435,8 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_bucket_cap_bytes: usize, shape_bucket_cap_bytes: usize,
shape_above_cap_blur: bool, shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize, shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool, )
) where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static, MR: AsyncRead + Unpin + Send + 'static,
@ -463,32 +450,32 @@ async fn relay_to_mask<R, W, MR, MW>(
return; return;
} }
let (upstream_copy, downstream_copy) = tokio::join!( let _ = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await }, async {
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await } let copied = copy_with_idle_timeout(&mut reader, &mut mask_write).await;
let total_sent = initial_data.len().saturating_add(copied.total);
let should_shape = shape_hardening_enabled
&& copied.ended_by_eof
&& !initial_data.is_empty();
maybe_write_shape_padding(
&mut mask_write,
total_sent,
should_shape,
shape_bucket_floor_bytes,
shape_bucket_cap_bytes,
shape_above_cap_blur,
shape_above_cap_blur_max_bytes,
)
.await;
let _ = mask_write.shutdown().await;
},
async {
let _ = copy_with_idle_timeout(&mut mask_read, &mut writer).await;
let _ = writer.shutdown().await;
}
); );
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
let should_shape = shape_hardening_enabled
&& !initial_data.is_empty()
&& (upstream_copy.ended_by_eof
|| (shape_hardening_aggressive_mode && downstream_copy.total == 0));
maybe_write_shape_padding(
&mut mask_write,
total_sent,
should_shape,
shape_bucket_floor_bytes,
shape_bucket_cap_bytes,
shape_above_cap_blur,
shape_above_cap_blur_max_bytes,
shape_hardening_aggressive_mode,
)
.await;
let _ = mask_write.shutdown().await;
let _ = writer.shutdown().await;
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
@ -537,14 +524,6 @@ mod masking_shape_guard_adversarial_tests;
#[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"] #[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"]
mod masking_shape_classifier_resistance_adversarial_tests; mod masking_shape_classifier_resistance_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_shape_bypass_blackhat_tests.rs"]
mod masking_shape_bypass_blackhat_tests;
#[cfg(test)]
#[path = "tests/masking_aggressive_mode_security_tests.rs"]
mod masking_aggressive_mode_security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] #[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests; mod masking_timing_sidechannel_redteam_expected_fail_tests;

View File

@ -1,6 +1,7 @@
use std::collections::hash_map::RandomState; use std::collections::hash_map::RandomState;
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use std::hash::{BuildHasher, Hash}; use std::hash::BuildHasher;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
@ -8,17 +9,17 @@ use std::time::{Duration, Instant};
use dashmap::DashMap; use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::{secure_padding_len, *}; use crate::protocol::constants::{*, secure_padding_len};
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state,
cutover_stagger_delay, cutover_stagger_delay,
}; };
use crate::stats::Stats; use crate::stats::Stats;
@ -49,16 +50,11 @@ const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const QUOTA_USER_LOCKS_MAX: usize = 64; const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))] #[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096; const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new(); static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new(); static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new(); static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@ -290,7 +286,9 @@ impl MeD2cFlushPolicy {
fn hash_value<T: Hash>(value: &T) -> u64 { fn hash_value<T: Hash>(value: &T) -> u64 {
let state = DESYNC_HASHER.get_or_init(RandomState::new); let state = DESYNC_HASHER.get_or_init(RandomState::new);
state.hash_one(value) let mut hasher = state.build_hasher();
value.hash(&mut hasher);
hasher.finish()
} }
fn hash_ip(ip: IpAddr) -> u64 { fn hash_ip(ip: IpAddr) -> u64 {
@ -418,13 +416,6 @@ fn desync_dedup_test_lock() -> &'static Mutex<()> {
TEST_LOCK.get_or_init(|| Mutex::new(())) TEST_LOCK.get_or_init(|| Mutex::new(()))
} }
fn desync_forensics_len_bytes(len: usize) -> ([u8; 4], bool) {
match u32::try_from(len) {
Ok(value) => (value.to_le_bytes(), false),
Err(_) => (u32::MAX.to_le_bytes(), true),
}
}
fn report_desync_frame_too_large( fn report_desync_frame_too_large(
state: &RelayForensicsState, state: &RelayForensicsState,
proto_tag: ProtoTag, proto_tag: ProtoTag,
@ -434,8 +425,7 @@ fn report_desync_frame_too_large(
raw_len_bytes: Option<[u8; 4]>, raw_len_bytes: Option<[u8; 4]>,
stats: &Stats, stats: &Stats,
) -> ProxyError { ) -> ProxyError {
let (fallback_len_buf, len_buf_truncated) = desync_forensics_len_bytes(len); let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes());
let len_buf = raw_len_bytes.unwrap_or(fallback_len_buf);
let looks_like_tls = raw_len_bytes let looks_like_tls = raw_len_bytes
.map(|b| b[0] == 0x16 && b[1] == 0x03) .map(|b| b[0] == 0x16 && b[1] == 0x03)
.unwrap_or(false); .unwrap_or(false);
@ -471,7 +461,6 @@ fn report_desync_frame_too_large(
bytes_me2c, bytes_me2c,
raw_len = len, raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len), raw_len_hex = format_args!("0x{:08x}", len),
raw_len_bytes_truncated = len_buf_truncated,
raw_bytes = format_args!( raw_bytes = format_args!(
"{:02x} {:02x} {:02x} {:02x}", "{:02x} {:02x} {:02x} {:02x}",
len_buf[0], len_buf[1], len_buf[2], len_buf[3] len_buf[0], len_buf[1], len_buf[2], len_buf[3]
@ -514,7 +503,8 @@ fn report_desync_frame_too_large(
ProxyError::Proxy(format!( ProxyError::Proxy(format!(
"Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}", "Frame too large: {len} (max {max_frame}), frames_ok={frame_counter}, conn_id={}, trace_id=0x{:016x}",
state.conn_id, state.trace_id state.conn_id,
state.trace_id
)) ))
} }
@ -538,30 +528,6 @@ fn quota_would_be_exceeded_for_user(
}) })
} }
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(AsyncMutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> { fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) { if let Some(existing) = locks.get(user) {
@ -573,7 +539,7 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
} }
if locks.len() >= QUOTA_USER_LOCKS_MAX { if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user); return Arc::new(AsyncMutex::new(()));
} }
let created = Arc::new(AsyncMutex::new(())); let created = Arc::new(AsyncMutex::new(()));
@ -663,9 +629,11 @@ where
stats.increment_user_connects(&user); stats.increment_user_connects(&user);
let _me_connection_lease = stats.acquire_me_connection_lease(); let _me_connection_lease = stats.acquire_me_connection_lease();
if let Some(cutover) = if let Some(cutover) = affected_cutover_state(
affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) &route_rx,
{ RelayRouteMode::Middle,
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation); let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!( warn!(
conn_id, conn_id,
@ -721,22 +689,21 @@ where
.max(C2ME_CHANNEL_CAPACITY_FALLBACK); .max(C2ME_CHANNEL_CAPACITY_FALLBACK);
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity); let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone(); let me_pool_c2me = me_pool.clone();
let effective_tag = effective_tag;
let c2me_sender = tokio::spawn(async move { let c2me_sender = tokio::spawn(async move {
let mut sent_since_yield = 0usize; let mut sent_since_yield = 0usize;
while let Some(cmd) = c2me_rx.recv().await { while let Some(cmd) = c2me_rx.recv().await {
match cmd { match cmd {
C2MeCommand::Data { payload, flags } => { C2MeCommand::Data { payload, flags } => {
me_pool_c2me me_pool_c2me.send_proxy_req(
.send_proxy_req( conn_id,
conn_id, success.dc_idx,
success.dc_idx, peer,
peer, translated_local_addr,
translated_local_addr, payload.as_ref(),
payload.as_ref(), flags,
flags, effective_tag.as_deref(),
effective_tag.as_deref(), ).await?;
)
.await?;
sent_since_yield = sent_since_yield.saturating_add(1); sent_since_yield = sent_since_yield.saturating_add(1);
if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) {
sent_since_yield = 0; sent_since_yield = 0;
@ -949,11 +916,7 @@ where
let mut seen_pressure_seq = relay_pressure_event_seq(); let mut seen_pressure_seq = relay_pressure_event_seq();
loop { loop {
if relay_idle_policy.enabled if relay_idle_policy.enabled
&& maybe_evict_idle_candidate_on_pressure( && maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen_pressure_seq, stats.as_ref())
conn_id,
&mut seen_pressure_seq,
stats.as_ref(),
)
{ {
info!( info!(
conn_id, conn_id,
@ -968,9 +931,11 @@ where
break; break;
} }
if let Some(cutover) = if let Some(cutover) = affected_cutover_state(
affected_cutover_state(&route_rx, RelayRouteMode::Middle, route_snapshot.generation) &route_rx,
{ RelayRouteMode::Middle,
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation); let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!( warn!(
conn_id, conn_id,
@ -1137,8 +1102,7 @@ where
return deadline; return deadline;
} }
let downstream_at = let downstream_at = session_started_at + Duration::from_millis(last_downstream_activity_ms);
session_started_at + Duration::from_millis(last_downstream_activity_ms);
if downstream_at > idle_state.last_client_frame_at { if downstream_at > idle_state.last_client_frame_at {
let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity;
if grace_deadline > deadline { if grace_deadline > deadline {
@ -1153,8 +1117,12 @@ where
let timeout_window = if idle_policy.enabled { let timeout_window = if idle_policy.enabled {
let now = Instant::now(); let now = Instant::now();
let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed);
let hard_deadline = let hard_deadline = hard_deadline(
hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); idle_policy,
idle_state,
session_started_at,
downstream_ms,
);
if now >= hard_deadline { if now >= hard_deadline {
clear_relay_idle_candidate(forensics.conn_id); clear_relay_idle_candidate(forensics.conn_id);
stats.increment_relay_idle_hard_close_total(); stats.increment_relay_idle_hard_close_total();
@ -1162,9 +1130,7 @@ where
.saturating_duration_since(idle_state.last_client_frame_at) .saturating_duration_since(idle_state.last_client_frame_at)
.as_secs(); .as_secs();
let downstream_idle_secs = now let downstream_idle_secs = now
.saturating_duration_since( .saturating_duration_since(session_started_at + Duration::from_millis(downstream_ms))
session_started_at + Duration::from_millis(downstream_ms),
)
.as_secs(); .as_secs();
warn!( warn!(
trace_id = format_args!("0x{:016x}", forensics.trace_id), trace_id = format_args!("0x{:016x}", forensics.trace_id),
@ -1238,9 +1204,7 @@ where
Err(_) if !idle_policy.enabled => { Err(_) if !idle_policy.enabled => {
return Err(ProxyError::Io(std::io::Error::new( return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut, std::io::ErrorKind::TimedOut,
format!( format!("middle-relay client frame read timeout while reading {read_label}"),
"middle-relay client frame read timeout while reading {read_label}"
),
))); )));
} }
Err(_) => {} Err(_) => {}
@ -1506,8 +1470,15 @@ where
user: user.to_string(), user: user.to_string(),
}); });
} }
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) write_client_payload(
.await?; client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64); stats.add_user_octets_to(user, data.len() as u64);
@ -1518,8 +1489,15 @@ where
}); });
} }
} else { } else {
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) write_client_payload(
.await?; client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64); stats.add_user_octets_to(user, data.len() as u64);
@ -1556,31 +1534,6 @@ where
} }
} }
fn compute_intermediate_secure_wire_len(
data_len: usize,
padding_len: usize,
quickack: bool,
) -> Result<(u32, usize)> {
let wire_len = data_len
.checked_add(padding_len)
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
if wire_len > 0x7fff_ffffusize {
return Err(ProxyError::Proxy(format!(
"Intermediate/Secure frame too large: {wire_len}"
)));
}
let total = 4usize
.checked_add(wire_len)
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
let mut len_val = u32::try_from(wire_len)
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
if quickack {
len_val |= 0x8000_0000;
}
Ok((len_val, total))
}
async fn write_client_payload<W>( async fn write_client_payload<W>(
client_writer: &mut CryptoWriter<W>, client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
@ -1650,8 +1603,11 @@ where
} else { } else {
0 0
}; };
let (len_val, total) = let mut len_val = (data.len() + padding_len) as u32;
compute_intermediate_secure_wire_len(data.len(), padding_len, quickack)?; if quickack {
len_val |= 0x8000_0000;
}
let total = 4 + data.len() + padding_len;
frame_buf.clear(); frame_buf.clear();
frame_buf.reserve(total); frame_buf.reserve(total);
frame_buf.extend_from_slice(&len_val.to_le_bytes()); frame_buf.extend_from_slice(&len_val.to_le_bytes());
@ -1701,23 +1657,3 @@ mod idle_policy_security_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] #[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"]
mod desync_all_full_dedup_security_tests; mod desync_all_full_dedup_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;

View File

@ -1,71 +1,13 @@
//! Proxy Defs //! Proxy Defs
// Apply strict linting to proxy production code while keeping test builds noise-tolerant.
#![cfg_attr(test, allow(warnings))]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
pub mod adaptive_buffers; pub mod adaptive_buffers;
pub mod client; pub mod client;
pub mod direct_relay; pub mod direct_relay;
pub mod handshake; pub mod handshake;
pub mod masking; pub mod masking;
pub mod middle_relay; pub mod middle_relay;
pub mod relay;
pub mod route_mode; pub mod route_mode;
pub mod relay;
pub mod session_eviction; pub mod session_eviction;
pub use client::ClientHandler; pub use client::ClientHandler;

View File

@ -51,19 +51,21 @@
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to` //! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking //! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool; use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
// ============= Constants ============= // ============= Constants =============
@ -249,8 +251,7 @@ impl<S> StatsIo<S> {
impl<S> Drop for StatsIo<S> { impl<S> Drop for StatsIo<S> {
fn drop(&mut self) { fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed); self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active self.quota_write_retry_active.store(false, Ordering::Relaxed);
.store(false, Ordering::Relaxed);
} }
} }
@ -427,9 +428,7 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
} }
// C→S: client sent data // C→S: client sent data
this.counters this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch); this.counters.touch(Instant::now(), this.epoch);
@ -468,8 +467,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
match lock.try_lock() { match lock.try_lock() {
Ok(guard) => { Ok(guard) => {
this.quota_write_wake_scheduled = false; this.quota_write_wake_scheduled = false;
this.quota_write_retry_active this.quota_write_retry_active.store(false, Ordering::Relaxed);
.store(false, Ordering::Relaxed);
Some(guard) Some(guard)
} }
Err(_) => { Err(_) => {
@ -511,9 +509,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n > 0 { if n > 0 {
// S→C: data written to client // S→C: data written to client
this.counters this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch); this.counters.touch(Instant::now(), this.epoch);
@ -790,4 +786,4 @@ mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)] #[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] #[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests; mod relay_quota_wake_liveness_regression_tests;

View File

@ -119,7 +119,9 @@ pub(crate) fn affected_cutover_state(
} }
pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration { pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration {
let mut value = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15; let mut value = session_id
^ generation.rotate_left(17)
^ 0x9e37_79b9_7f4a_7c15;
value ^= value >> 30; value ^= value >> 30;
value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9); value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9);
value ^= value >> 27; value ^= value >> 27;

View File

@ -1,5 +1,3 @@
#![allow(dead_code)]
/// Session eviction is intentionally disabled in runtime. /// Session eviction is intentionally disabled in runtime.
/// ///
/// The initial `user+dc` single-lease model caused valid parallel client /// The initial `user+dc` single-lease model caused valid parallel client

View File

@ -1,11 +1,11 @@
use super::*; use super::*;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::error::ProxyError;
use crate::ip_tracker::UserIpTracker;
use crate::stats::Stats; use crate::stats::Stats;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use crate::ip_tracker::UserIpTracker;
use crate::error::ProxyError;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) // Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6)
@ -15,16 +15,13 @@ use std::sync::atomic::{AtomicU64, Ordering};
async fn client_stress_10k_connections_limit_strict() { async fn client_stress_10k_connections_limit_strict() {
let user = "stress-user"; let user = "stress-user";
let limit = 512; let limit = 512;
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), limit);
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let iterations = 1000; let iterations = 1000;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
@ -33,18 +30,20 @@ async fn client_stress_10k_connections_limit_strict() {
let ip_tracker = Arc::clone(&ip_tracker); let ip_tracker = Arc::clone(&ip_tracker);
let config = config.clone(); let config = config.clone();
let user_str = user.to_string(); let user_str = user.to_string();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new( let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)), IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)),
10000 + (i % 1000) as u16, 10000 + (i % 1000) as u16,
); );
match RunningClientHandler::acquire_user_connection_reservation_static( match RunningClientHandler::acquire_user_connection_reservation_static(
&user_str, &config, stats, peer, ip_tracker, &user_str,
) &config,
.await stats,
{ peer,
ip_tracker,
).await {
Ok(res) => Ok(res), Ok(res) => Ok(res),
Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()), Err(ProxyError::ConnectionLimitExceeded { .. }) => Err(()),
Err(e) => panic!("Unexpected error: {:?}", e), Err(e) => panic!("Unexpected error: {:?}", e),
@ -68,27 +67,15 @@ async fn client_stress_10k_connections_limit_strict() {
} }
assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); assert_eq!(successes, limit, "Should allow exactly 'limit' connections");
assert_eq!( assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded");
failures,
iterations - limit,
"Should fail the rest with LimitExceeded"
);
assert_eq!(stats.get_user_curr_connects(user), limit as u64); assert_eq!(stats.get_user_curr_connects(user), limit as u64);
drop(reservations); drop(reservations);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!( assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops");
stats.get_user_curr_connects(user), assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0");
0,
"Stats must converge to 0 after all drops"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
0,
"IP tracker must converge to 0"
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -100,14 +87,14 @@ async fn client_ip_tracker_race_condition_stress() {
let user = "race-user"; let user = "race-user";
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 100).await; ip_tracker.set_user_limit(user, 100).await;
let iterations = 1000; let iterations = 1000;
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..iterations { for i in 0..iterations {
let ip_tracker = Arc::clone(&ip_tracker); let ip_tracker = Arc::clone(&ip_tracker);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8)); let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8));
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
for _ in 0..10 { for _ in 0..10 {
if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await { if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await {
@ -118,12 +105,8 @@ async fn client_ip_tracker_race_condition_stress() {
} }
futures::future::join_all(tasks).await; futures::future::join_all(tasks).await;
assert_eq!( assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst");
ip_tracker.get_active_ip_count(user).await,
0,
"IP count must be zero after balanced add/remove burst"
);
} }
#[tokio::test] #[tokio::test]
@ -136,10 +119,7 @@ async fn client_limit_burst_peak_never_exceeds_cap() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), limit);
.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
let peak = Arc::new(AtomicU64::new(0)); let peak = Arc::new(AtomicU64::new(0));
let mut tasks = Vec::with_capacity(attempts); let mut tasks = Vec::with_capacity(attempts);
@ -227,10 +207,10 @@ async fn client_expiration_rejection_never_mutates_live_counters() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_expirations.insert( config
user.to_string(), .access
chrono::Utc::now() - chrono::Duration::seconds(1), .user_expirations
); .insert(user.to_string(), chrono::Utc::now() - chrono::Duration::seconds(1));
let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap();
let res = RunningClientHandler::acquire_user_connection_reservation_static( let res = RunningClientHandler::acquire_user_connection_reservation_static(
@ -255,10 +235,7 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() {
ip_tracker.set_user_limit(user, 1).await; ip_tracker.set_user_limit(user, 1).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 16);
.access
.user_max_tcp_conns
.insert(user.to_string(), 16);
let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap();
let first = RunningClientHandler::acquire_user_connection_reservation_static( let first = RunningClientHandler::acquire_user_connection_reservation_static(
@ -281,10 +258,7 @@ async fn client_ip_limit_failure_rolls_back_counter_exactly() {
) )
.await; .await;
assert!(matches!( assert!(matches!(second, Err(ProxyError::ConnectionLimitExceeded { .. })));
second,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
assert_eq!(stats.get_user_curr_connects(user), 1); assert_eq!(stats.get_user_curr_connects(user), 1);
drop(first); drop(first);
@ -302,10 +276,7 @@ async fn client_parallel_limit_checks_success_path_leaves_no_residue() {
ip_tracker.set_user_limit(user, 128).await; ip_tracker.set_user_limit(user, 128).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 128);
.access
.user_max_tcp_conns
.insert(user.to_string(), 128);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..128u16 { for i in 0..128u16 {
@ -339,10 +310,7 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() {
ip_tracker.set_user_limit(user, 0).await; ip_tracker.set_user_limit(user, 0).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 512);
.access
.user_max_tcp_conns
.insert(user.to_string(), 512);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..64u16 { for i in 0..64u16 {
@ -351,10 +319,7 @@ async fn client_parallel_limit_checks_failure_path_leaves_no_residue() {
let config = config.clone(); let config = config.clone();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let peer = SocketAddr::new( let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), 33000 + i);
IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)),
33000 + i,
);
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker)
.await .await
})); }));
@ -395,7 +360,11 @@ async fn client_churn_mixed_success_failure_converges_to_zero_state() {
34000 + (i % 32), 34000 + (i % 32),
); );
let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await; .await;
@ -432,7 +401,11 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one()
let config = config.clone(); let config = config.clone();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
})); }));
@ -451,10 +424,7 @@ async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one()
} }
} }
assert_eq!( assert_eq!(granted, 1, "only one reservation may be granted for same IP with limit=1");
granted, 1,
"only one reservation may be granted for same IP with limit=1"
);
drop(reservations); drop(reservations);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!(stats.get_user_curr_connects(user), 0); assert_eq!(stats.get_user_curr_connects(user), 0);
@ -469,10 +439,7 @@ async fn client_repeat_acquire_release_cycles_never_accumulate_state() {
ip_tracker.set_user_limit(user, 32).await; ip_tracker.set_user_limit(user, 32).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 32);
.access
.user_max_tcp_conns
.insert(user.to_string(), 32);
for i in 0..500u16 { for i in 0..500u16 {
let peer = SocketAddr::new( let peer = SocketAddr::new(
@ -517,7 +484,11 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() {
37000 + i, 37000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
})); }));
@ -526,11 +497,7 @@ async fn client_multi_user_isolation_under_parallel_limit_exhaustion() {
let mut u1_success = 0usize; let mut u1_success = 0usize;
let mut u2_success = 0usize; let mut u2_success = 0usize;
let mut reservations = Vec::new(); let mut reservations = Vec::new();
for (idx, result) in futures::future::join_all(tasks) for (idx, result) in futures::future::join_all(tasks).await.into_iter().enumerate() {
.await
.into_iter()
.enumerate()
{
let user = if idx % 2 == 0 { "u1" } else { "u2" }; let user = if idx % 2 == 0 { "u1" } else { "u2" };
match result.unwrap() { match result.unwrap() {
Ok(reservation) => { Ok(reservation) => {
@ -589,10 +556,7 @@ async fn client_limit_recovery_after_full_rejection_wave() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!(matches!( assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. })));
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
} }
drop(reservation); drop(reservation);
@ -608,10 +572,7 @@ async fn client_limit_recovery_after_full_rejection_wave() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!( assert!(recovered.is_ok(), "capacity must recover after prior holder drops");
recovered.is_ok(),
"capacity must recover after prior holder drops"
);
} }
#[tokio::test] #[tokio::test]
@ -658,10 +619,7 @@ async fn client_dual_limit_cross_product_never_leaks_on_reject() {
ip_tracker.clone(), ip_tracker.clone(),
) )
.await; .await;
assert!(matches!( assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. })));
denied,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
} }
assert_eq!(stats.get_user_curr_connects(user), 2); assert_eq!(stats.get_user_curr_connects(user), 2);
@ -679,10 +637,7 @@ async fn client_check_user_limits_concurrent_churn_no_counter_drift() {
ip_tracker.set_user_limit(user, 64).await; ip_tracker.set_user_limit(user, 64).await;
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user.to_string(), 64);
.access
.user_max_tcp_conns
.insert(user.to_string(), 64);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..512u16 { for i in 0..512u16 {

View File

@ -2,14 +2,17 @@ use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{ use crate::protocol::constants::{
HANDSHAKE_LEN, MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE, TLS_RECORD_APPLICATION, HANDSHAKE_LEN,
MAX_TLS_PLAINTEXT_SIZE,
MIN_TLS_CLIENT_HELLO_SIZE,
TLS_RECORD_APPLICATION,
TLS_VERSION, TLS_VERSION,
}; };
use crate::protocol::tls; use crate::protocol::tls;
use std::collections::HashSet; use std::collections::HashSet;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -76,10 +79,7 @@ fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -171,10 +171,7 @@ async fn run_tls_success_mtproto_fail_capture(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
@ -430,10 +427,7 @@ async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello()
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await; read_and_discard_tls_record_body(&mut client_side, head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap(); client_side.write_all(&first_tail).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -703,15 +697,13 @@ async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_is
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..sessions { for i in 0..sessions {
let mut harness = let mut harness = build_mask_harness("abababababababababababababababab", backend_addr.port());
build_mask_harness("abababababababababababababababab", backend_addr.port());
let mut cfg = (*harness.config).clone(); let mut cfg = (*harness.config).clone();
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
harness.config = Arc::new(cfg); harness.config = Arc::new(cfg);
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let secret = [0xABu8; 16]; let secret = [0xABu8; 16];
let hello = let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10));
make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10));
let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]); let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]);
let (server_side, mut client_side) = duplex(131072); let (server_side, mut client_side) = duplex(131072);
@ -851,12 +843,12 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() {
tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024); tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024);
} }
let body_to_send = let body_to_send = if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len)
if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) { {
(seed as usize % 29).min(tls_len.saturating_sub(1)) (seed as usize % 29).min(tls_len.saturating_sub(1))
} else { } else {
0 0
}; };
let mut probe = vec![0u8; 5 + body_to_send]; let mut probe = vec![0u8; 5 + body_to_send];
probe[0] = 0x16; probe[0] = 0x16;
@ -864,9 +856,7 @@ async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() {
probe[2] = 0x01; probe[2] = 0x01;
probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
for b in &mut probe[5..] { for b in &mut probe[5..] {
seed = seed seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
*b = (seed >> 24) as u8; *b = (seed >> 24) as u8;
} }
@ -889,8 +879,7 @@ async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics()
probe[2] = 0x01; probe[2] = 0x01;
probe[3..5].copy_from_slice(&600u16.to_be_bytes()); probe[3..5].copy_from_slice(&600u16.to_be_bytes());
probe[5..].fill((0x90 + i as u8) ^ 0x5A); probe[5..].fill((0x90 + i as u8) ^ 0x5A);
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe) run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe).await;
.await;
} else { } else {
let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8]; let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8];
run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await; run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await;

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -55,10 +55,7 @@ fn build_harness(config: ProxyConfig) -> PipelineHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -153,10 +150,7 @@ async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay()
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert!( assert!(result.is_ok(), "bad-client fallback must not be canceled by handshake timeout");
result.is_ok(),
"bad-client fallback must not be canceled by handshake timeout"
);
assert_eq!( assert_eq!(
stats.get_handshake_timeouts(), stats.get_handshake_timeouts(),
0, 0,
@ -181,10 +175,10 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
config.censorship.mask_port = backend_addr.port(); config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0; config.censorship.mask_proxy_protocol = 0;
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
config.access.users.insert( config
"user".to_string(), .access
"d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string(), .users
); .insert("user".to_string(), "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string());
let harness = build_harness(config); let harness = build_harness(config);
@ -200,7 +194,8 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
let mut got = vec![0u8; expected_trailing.len()]; let mut got = vec![0u8; expected_trailing.len()];
stream.read_exact(&mut got).await.unwrap(); stream.read_exact(&mut got).await.unwrap();
assert_eq!( assert_eq!(
got, expected_trailing, got,
expected_trailing,
"mask backend must receive only post-handshake trailing TLS records" "mask backend must receive only post-handshake trailing TLS records"
); );
}); });
@ -228,17 +223,11 @@ async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -163,36 +163,21 @@ async fn diagnostic_timing_profiles_are_within_realistic_guardrails() {
); );
assert!(p50 >= 650, "p50 too low for delayed reject class={}", class); assert!(p50 >= 650, "p50 too low for delayed reject class={}", class);
assert!( assert!(p95 <= 1200, "p95 too high for delayed reject class={}", class);
p95 <= 1200, assert!(max <= 1500, "max too high for delayed reject class={}", class);
"p95 too high for delayed reject class={}",
class
);
assert!(
max <= 1500,
"max too high for delayed reject class={}",
class
);
} }
} }
#[tokio::test] #[tokio::test]
async fn diagnostic_forwarded_size_profiles_by_probe_class() { async fn diagnostic_forwarded_size_profiles_by_probe_class() {
let classes = [ let classes = [0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize];
0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize,
];
let mut observed = Vec::new(); let mut observed = Vec::new();
for class in classes { for class in classes {
let len = capture_forwarded_len(class).await; let len = capture_forwarded_len(class).await;
println!("diagnostic_shape class={} forwarded_len={}", class, len); println!("diagnostic_shape class={} forwarded_len={}", class, len);
observed.push(len as u128); observed.push(len as u128);
assert_eq!( assert_eq!(len, 5 + class, "unexpected forwarded len for class={}", class);
len,
5 + class,
"unexpected forwarded len for class={}",
class
);
} }
let p50 = percentile_ms(observed.clone(), 50, 100); let p50 = percentile_ms(observed.clone(), 50, 100);

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -70,10 +70,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> Harness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -161,17 +158,11 @@ async fn run_tls_success_mtproto_fail_capture(
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_tls_record_body(&mut client_side, tls_response_head).await; read_tls_record_body(&mut client_side, tls_response_head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for record in trailing_records { for record in trailing_records {
client_side.write_all(&record).await.unwrap(); client_side.write_all(&record).await.unwrap();
} }
@ -339,10 +330,7 @@ async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() {
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
read_tls_record_body(&mut client_side, head).await; read_tls_record_body(&mut client_side, head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&first_tail).await.unwrap(); client_side.write_all(&first_tail).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -414,10 +402,7 @@ async fn connects_bad_increments_once_per_invalid_mtproto() {
let mut head = [0u8; 5]; let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
read_tls_record_body(&mut client_side, head).await; read_tls_record_body(&mut client_side, head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&tail).await.unwrap(); client_side.write_all(&tail).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -640,8 +625,7 @@ async fn concurrent_tls_mtproto_fail_sessions_are_isolated() {
for idx in 0..sessions { for idx in 0..sessions {
let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"; let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4";
let harness = build_harness(secret_hex, backend_addr.port()); let harness = build_harness(secret_hex, backend_addr.port());
let hello = let hello = make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8);
make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8);
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]); let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]);
let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16) let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16)
@ -701,67 +685,17 @@ macro_rules! tail_length_case {
*b = (i as u8).wrapping_mul(17).wrapping_add(5); *b = (i as u8).wrapping_mul(17).wrapping_add(5);
} }
let record = wrap_tls_application_data(&payload); let record = wrap_tls_application_data(&payload);
let got = let got = run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]).await;
run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()])
.await;
assert_eq!(got, record); assert_eq!(got, record);
} }
}; };
} }
tail_length_case!( tail_length_case!(tail_len_1_preserved, "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", [0xD1; 16], 30, 1);
tail_len_1_preserved, tail_length_case!(tail_len_2_preserved, "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", [0xD2; 16], 31, 2);
"d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", tail_length_case!(tail_len_3_preserved, "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", [0xD3; 16], 32, 3);
[0xD1; 16], tail_length_case!(tail_len_7_preserved, "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", [0xD4; 16], 33, 7);
30, tail_length_case!(tail_len_31_preserved, "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", [0xD5; 16], 34, 31);
1 tail_length_case!(tail_len_127_preserved, "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", [0xD6; 16], 35, 127);
); tail_length_case!(tail_len_511_preserved, "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", [0xD7; 16], 36, 511);
tail_length_case!( tail_length_case!(tail_len_1023_preserved, "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", [0xD8; 16], 37, 1023);
tail_len_2_preserved,
"d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2",
[0xD2; 16],
31,
2
);
tail_length_case!(
tail_len_3_preserved,
"d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3",
[0xD3; 16],
32,
3
);
tail_length_case!(
tail_len_7_preserved,
"d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4",
[0xD4; 16],
33,
7
);
tail_length_case!(
tail_len_31_preserved,
"d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5",
[0xD5; 16],
34,
31
);
tail_length_case!(
tail_len_127_preserved,
"d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6",
[0xD6; 16],
35,
127
);
tail_length_case!(
tail_len_511_preserved,
"d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7",
[0xD7; 16],
36,
511
);
tail_length_case!(
tail_len_1023_preserved,
"d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8",
[0xD8; 16],
37,
1023
);

View File

@ -5,7 +5,7 @@ use rand::{Rng, SeedableRng};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
@ -92,13 +92,10 @@ async fn run_generic_probe_and_capture_prefix(payload: Vec<u8>, expected_prefix:
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout( tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
Duration::from_secs(2), .await
client_side.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404); assert_eq!(observed, REPLY_404);
let got = tokio::time::timeout(Duration::from_secs(2), accept_task) let got = tokio::time::timeout(Duration::from_secs(2), accept_task)
@ -267,8 +264,7 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let mut expected = std::collections::HashSet::new(); let mut expected = std::collections::HashSet::new();
for idx in 0..session_count { for idx in 0..session_count {
let probe = let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
expected.insert(probe); expected.insert(probe);
} }
@ -278,15 +274,9 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let head = read_http_probe_header(&mut stream).await; let head = read_http_probe_header(&mut stream).await;
stream.write_all(REPLY_404).await.unwrap(); stream.write_all(REPLY_404).await.unwrap();
assert!( assert!(remaining.remove(&head), "backend received unexpected or duplicated probe prefix");
remaining.remove(&head),
"backend received unexpected or duplicated probe prefix"
);
} }
assert!( assert!(remaining.is_empty(), "all session prefixes must be observed exactly once");
remaining.is_empty(),
"all session prefixes must be observed exactly once"
);
}); });
let mut tasks = Vec::with_capacity(session_count); let mut tasks = Vec::with_capacity(session_count);
@ -301,8 +291,7 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new()); let beobachten = Arc::new(BeobachtenStore::new());
let probe = let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes();
let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx)
.parse() .parse()
.unwrap(); .unwrap();
@ -330,13 +319,10 @@ async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() {
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout( tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
Duration::from_secs(2), .await
client_side.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404); assert_eq!(observed, REPLY_404);
let result = tokio::time::timeout(Duration::from_secs(2), handler) let result = tokio::time::timeout(Duration::from_secs(2), handler)

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -67,10 +67,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -151,14 +148,8 @@ async fn run_tls_success_mtproto_fail_session(
let mut body = vec![0u8; body_len]; let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap(); client_side.read_exact(&mut body).await.unwrap();
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record) client_side.write_all(&wrap_tls_application_data(&tail)).await.unwrap();
.await
.unwrap();
client_side
.write_all(&wrap_tls_application_data(&tail))
.await
.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await .await
@ -184,10 +175,7 @@ async fn redteam_01_backend_receives_no_data_after_mtproto_fail() {
b"probe-a".to_vec(), b"probe-a".to_vec(),
) )
.await; .await;
assert!( assert!(forwarded.is_empty(), "backend unexpectedly received fallback bytes");
forwarded.is_empty(),
"backend unexpectedly received fallback bytes"
);
} }
#[tokio::test] #[tokio::test]
@ -200,10 +188,7 @@ async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail()
b"probe-b".to_vec(), b"probe-b".to_vec(),
) )
.await; .await;
assert_ne!( assert_ne!(forwarded[0], 0x17, "received TLS application record despite strict policy");
forwarded[0], 0x17,
"received TLS application record despite strict policy"
);
} }
#[tokio::test] #[tokio::test]
@ -215,10 +200,9 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() {
cfg.censorship.mask_host = Some("127.0.0.1".to_string()); cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1; cfg.censorship.mask_port = 1;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"acacacacacacacacacacacacacacacac".to_string(), .insert("user".to_string(), "acacacacacacacacacacacacacacacac".to_string());
);
let harness = RedTeamHarness { let harness = RedTeamHarness {
config: Arc::new(cfg), config: Arc::new(cfg),
@ -277,10 +261,7 @@ async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() {
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert!( assert!(started.elapsed() < Duration::from_millis(1), "fallback path took longer than 1ms");
started.elapsed() < Duration::from_millis(1),
"fallback path took longer than 1ms"
);
} }
macro_rules! redteam_tail_must_not_forward_case { macro_rules! redteam_tail_must_not_forward_case {
@ -302,90 +283,18 @@ macro_rules! redteam_tail_must_not_forward_case {
}; };
} }
redteam_tail_must_not_forward_case!( redteam_tail_must_not_forward_case!(redteam_04_tail_len_1_not_forwarded, "adadadadadadadadadadadadadadadad", [0xAD; 16], 4, 1);
redteam_04_tail_len_1_not_forwarded, redteam_tail_must_not_forward_case!(redteam_05_tail_len_2_not_forwarded, "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", [0xAE; 16], 5, 2);
"adadadadadadadadadadadadadadadad", redteam_tail_must_not_forward_case!(redteam_06_tail_len_3_not_forwarded, "afafafafafafafafafafafafafafafaf", [0xAF; 16], 6, 3);
[0xAD; 16], redteam_tail_must_not_forward_case!(redteam_07_tail_len_7_not_forwarded, "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", [0xB0; 16], 7, 7);
4, redteam_tail_must_not_forward_case!(redteam_08_tail_len_15_not_forwarded, "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", [0xB1; 16], 8, 15);
1 redteam_tail_must_not_forward_case!(redteam_09_tail_len_63_not_forwarded, "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", [0xB2; 16], 9, 63);
); redteam_tail_must_not_forward_case!(redteam_10_tail_len_127_not_forwarded, "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", [0xB3; 16], 10, 127);
redteam_tail_must_not_forward_case!( redteam_tail_must_not_forward_case!(redteam_11_tail_len_255_not_forwarded, "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", [0xB4; 16], 11, 255);
redteam_05_tail_len_2_not_forwarded, redteam_tail_must_not_forward_case!(redteam_12_tail_len_511_not_forwarded, "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", [0xB5; 16], 12, 511);
"aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", redteam_tail_must_not_forward_case!(redteam_13_tail_len_1023_not_forwarded, "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", [0xB6; 16], 13, 1023);
[0xAE; 16], redteam_tail_must_not_forward_case!(redteam_14_tail_len_2047_not_forwarded, "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", [0xB7; 16], 14, 2047);
5, redteam_tail_must_not_forward_case!(redteam_15_tail_len_4095_not_forwarded, "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", [0xB8; 16], 15, 4095);
2
);
redteam_tail_must_not_forward_case!(
redteam_06_tail_len_3_not_forwarded,
"afafafafafafafafafafafafafafafaf",
[0xAF; 16],
6,
3
);
redteam_tail_must_not_forward_case!(
redteam_07_tail_len_7_not_forwarded,
"b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0",
[0xB0; 16],
7,
7
);
redteam_tail_must_not_forward_case!(
redteam_08_tail_len_15_not_forwarded,
"b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1",
[0xB1; 16],
8,
15
);
redteam_tail_must_not_forward_case!(
redteam_09_tail_len_63_not_forwarded,
"b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2",
[0xB2; 16],
9,
63
);
redteam_tail_must_not_forward_case!(
redteam_10_tail_len_127_not_forwarded,
"b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3",
[0xB3; 16],
10,
127
);
redteam_tail_must_not_forward_case!(
redteam_11_tail_len_255_not_forwarded,
"b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4",
[0xB4; 16],
11,
255
);
redteam_tail_must_not_forward_case!(
redteam_12_tail_len_511_not_forwarded,
"b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5",
[0xB5; 16],
12,
511
);
redteam_tail_must_not_forward_case!(
redteam_13_tail_len_1023_not_forwarded,
"b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6",
[0xB6; 16],
13,
1023
);
redteam_tail_must_not_forward_case!(
redteam_14_tail_len_2047_not_forwarded,
"b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7",
[0xB7; 16],
14,
2047
);
redteam_tail_must_not_forward_case!(
redteam_15_tail_len_4095_not_forwarded,
"b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8",
[0xB8; 16],
15,
4095
);
#[tokio::test] #[tokio::test]
#[ignore = "red-team expected-fail: impossible indistinguishability envelope"] #[ignore = "red-team expected-fail: impossible indistinguishability envelope"]
@ -440,13 +349,14 @@ async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency
let min = durations.iter().copied().min().unwrap(); let min = durations.iter().copied().min().unwrap();
let max = durations.iter().copied().max().unwrap(); let max = durations.iter().copied().max().unwrap();
assert!( assert!(max - min <= Duration::from_millis(1), "timing spread too wide for strict anti-probing envelope");
max - min <= Duration::from_millis(1),
"timing spread too wide for strict anti-probing envelope"
);
} }
async fn measure_invalid_probe_duration_ms(delay_ms: u64, tls_len: u16, body_sent: usize) -> u128 { async fn measure_invalid_probe_duration_ms(
delay_ms: u64,
tls_len: u16,
body_sent: usize,
) -> u128 {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false; cfg.general.beobachten = false;
cfg.censorship.mask = true; cfg.censorship.mask = true;
@ -591,8 +501,7 @@ macro_rules! redteam_timing_envelope_case {
#[tokio::test] #[tokio::test]
#[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"] #[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"]
async fn $name() { async fn $name() {
let elapsed_ms = let elapsed_ms = measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await;
measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await;
assert!( assert!(
elapsed_ms <= $max_ms, elapsed_ms <= $max_ms,
"timing envelope violated: elapsed={}ms, max={}ms", "timing envelope violated: elapsed={}ms, max={}ms",
@ -610,9 +519,11 @@ macro_rules! redteam_constant_shape_case {
async fn $name() { async fn $name() {
let got = capture_forwarded_probe_len($tls_len, $body_sent).await; let got = capture_forwarded_probe_len($tls_len, $body_sent).await;
assert_eq!( assert_eq!(
got, $expected_len, got,
$expected_len,
"fingerprint shape mismatch: got={} expected={} (strict constant-shape model)", "fingerprint shape mismatch: got={} expected={} (strict constant-shape model)",
got, $expected_len got,
$expected_len
); );
} }
}; };

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -172,10 +172,7 @@ async fn redteam_fuzz_01_hardened_output_length_correlation_should_be_below_0_2(
let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect(); let y_hard: Vec<f64> = hardened.iter().map(|v| *v as f64).collect();
let corr_hard = pearson_corr(&x, &y_hard).abs(); let corr_hard = pearson_corr(&x, &y_hard).abs();
println!( println!("redteam_fuzz corr_hardened={corr_hard:.4} samples={}", sizes.len());
"redteam_fuzz corr_hardened={corr_hard:.4} samples={}",
sizes.len()
);
assert!( assert!(
corr_hard < 0.2, corr_hard < 0.2,
@ -237,7 +234,9 @@ async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() {
let corr_plain = pearson_corr(&x, &y_plain).abs(); let corr_plain = pearson_corr(&x, &y_plain).abs();
let corr_hard = pearson_corr(&x, &y_hard).abs(); let corr_hard = pearson_corr(&x, &y_hard).abs();
println!("redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}"); println!(
"redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}"
);
assert!( assert!(
corr_hard <= corr_plain * 0.1, corr_hard <= corr_plain * 0.1,

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, Instant}; use tokio::time::{Duration, Instant};
@ -164,7 +164,10 @@ async fn redteam_shape_02_padding_tail_must_be_non_deterministic() {
let cap = 4096usize; let cap = 4096usize;
let got = run_probe_capture(17, 600, true, floor, cap).await; let got = run_probe_capture(17, 600, true, floor, cap).await;
assert!(got.len() > 22, "test requires padding tail to exist"); assert!(
got.len() > 22,
"test requires padding tail to exist"
);
let tail = &got[22..]; let tail = &got[22..];
assert!( assert!(
@ -191,9 +194,7 @@ async fn redteam_shape_03_exact_floor_input_should_not_be_fixed_point() {
async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() {
let floor = 512usize; let floor = 512usize;
let cap = 4096usize; let cap = 4096usize;
let classes = [ let classes = [17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize];
17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize,
];
let mut observed = Vec::new(); let mut observed = Vec::new();
for body in classes { for body in classes {
@ -202,10 +203,7 @@ async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() {
let first = observed[0]; let first = observed[0];
for v in observed { for v in observed {
assert_eq!( assert_eq!(v, first, "strict model expects one collapsed class across all sub-cap probes");
v, first,
"strict model expects one collapsed class across all sub-cap probes"
);
} }
} }

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -3,7 +3,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -70,10 +70,7 @@ fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -153,8 +150,12 @@ async fn run_parallel_tail_fallback_case(
for idx in 0..sessions { for idx in 0..sessions {
let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0"); let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0");
let hello = let hello = make_valid_tls_client_hello(
make_valid_tls_client_hello(&[0xE0; 16], ts_base + idx as u32, 600, 0x40 + (idx as u8)); &[0xE0; 16],
ts_base + idx as u32,
600,
0x40 + (idx as u8),
);
let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3];
@ -169,8 +170,8 @@ async fn run_parallel_tail_fallback_case(
peer_ip_fourth, peer_ip_fourth,
peer_port_base + idx as u16 peer_port_base + idx as u16
) )
.parse() .parse()
.unwrap(); .unwrap();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let (server_side, mut client_side) = duplex(262144); let (server_side, mut client_side) = duplex(262144);
@ -193,10 +194,7 @@ async fn run_parallel_tail_fallback_case(
client_side.write_all(&hello).await.unwrap(); client_side.write_all(&hello).await.unwrap();
let mut server_hello_head = [0u8; 5]; let mut server_hello_head = [0u8; 5];
client_side client_side.read_exact(&mut server_hello_head).await.unwrap();
.read_exact(&mut server_hello_head)
.await
.unwrap();
assert_eq!(server_hello_head[0], 0x16); assert_eq!(server_hello_head[0], 0x16);
read_tls_record_body(&mut client_side, server_hello_head).await; read_tls_record_body(&mut client_side, server_hello_head).await;

View File

@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter}; use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder; use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
#[test] #[test]
@ -49,33 +49,25 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
let stats = Arc::new(crate::stats::Stats::new()); let stats = Arc::new(crate::stats::Stats::new());
let user = "sync-drop-user".to_string(); let user = "sync-drop-user".to_string();
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap(); let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
ip_tracker.set_user_limit(&user, 1).await; ip_tracker.set_user_limit(&user, 1).await;
ip_tracker.check_and_add(&user, ip).await.unwrap(); ip_tracker.check_and_add(&user, ip).await.unwrap();
stats.increment_user_curr_connects(&user); stats.increment_user_curr_connects(&user);
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1);
assert_eq!(stats.get_user_curr_connects(&user), 1); assert_eq!(stats.get_user_curr_connects(&user), 1);
let reservation = let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
// Drop the reservation synchronously without any tokio::spawn/await yielding! // Drop the reservation synchronously without any tokio::spawn/await yielding!
drop(reservation); drop(reservation);
// The IP is now inside the cleanup_queue, check that the queue has length 1 // The IP is now inside the cleanup_queue, check that the queue has length 1
let queue_len = ip_tracker.cleanup_queue_len_for_tests(); let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len();
assert_eq!( assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue");
queue_len, 1,
"Reservation drop must push directly to synchronized IP queue" assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately");
);
assert_eq!(
stats.get_user_curr_connects(&user),
0,
"Stats must decrement immediately"
);
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0);
} }
@ -294,10 +286,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
.await .await
.expect("relay must terminate after cutover") .expect("relay must terminate after cutover")
.expect("relay task must not panic"); .expect("relay task must not panic");
assert!( assert!(relay_result.is_err(), "cutover must terminate direct relay session");
relay_result.is_err(),
"cutover must terminate direct relay session"
);
assert_eq!( assert_eq!(
stats.get_user_curr_connects(user), stats.get_user_curr_connects(user),
@ -458,12 +447,7 @@ async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() {
let mut reservations = Vec::new(); let mut reservations = Vec::new();
for idx in 0..512u16 { for idx in 0..512u16 {
let peer = std::net::SocketAddr::new( let peer = std::net::SocketAddr::new(
std::net::IpAddr::V4(std::net::Ipv4Addr::new( std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)),
198,
51,
(idx >> 8) as u8,
(idx & 0xff) as u8,
)),
30_000 + idx, 30_000 + idx,
); );
let reservation = RunningClientHandler::acquire_user_connection_reservation_static( let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
@ -526,15 +510,10 @@ async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() {
false, false,
stats.clone(), stats.clone(),
)); ));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60)));
128,
std::time::Duration::from_secs(60),
));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new( let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct));
crate::proxy::route_mode::RelayRouteMode::Direct,
));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
@ -602,16 +581,10 @@ async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load(
false, false,
stats.clone(), stats.clone(),
)); ));
let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new( let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60)));
64,
std::time::Duration::from_secs(60),
));
let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
let route_runtime = let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct));
std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(
crate::proxy::route_mode::RelayRouteMode::Direct,
));
let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());
@ -696,16 +669,8 @@ async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user), matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
"second reservation must be rejected at the configured tcp-conns limit" "second reservation must be rejected at the configured tcp-conns limit"
); );
assert_eq!( assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment");
stats.get_user_curr_connects(user), assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state");
1,
"failed acquisition must not leak a counter increment"
);
assert_eq!(
ip_tracker.get_active_ip_count(user).await,
1,
"failed acquisition must not mutate IP tracker state"
);
first.release().await; first.release().await;
ip_tracker.drain_cleanup_queue().await; ip_tracker.drain_cleanup_queue().await;
@ -1154,10 +1119,7 @@ async fn partial_tls_header_stall_triggers_handshake_timeout() {
} }
fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> { fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![0x42u8; total_len]; let mut handshake = vec![0x42u8; total_len];
@ -1178,8 +1140,7 @@ fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len:
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
.copy_from_slice(&digest);
handshake handshake
} }
@ -1242,7 +1203,8 @@ fn make_valid_tls_client_hello_with_alpn(
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record record
} }
@ -1271,10 +1233,9 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"11111111111111111111111111111111".to_string(), .insert("user".to_string(), "11111111111111111111111111111111".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1346,7 +1307,8 @@ async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
let bad_after = stats_for_assert.get_connects_bad(); let bad_after = stats_for_assert.get_connects_bad();
assert_eq!( assert_eq!(
bad_before, bad_after, bad_before,
bad_after,
"Authenticated TLS path must not increment connects_bad" "Authenticated TLS path must not increment connects_bad"
); );
} }
@ -1379,10 +1341,9 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"33333333333333333333333333333333".to_string(), .insert("user".to_string(), "33333333333333333333333333333333".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1433,10 +1394,7 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side.write_all(&tls_app_record).await.unwrap(); client_side.write_all(&tls_app_record).await.unwrap();
@ -1485,10 +1443,9 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"44444444444444444444444444444444".to_string(), .insert("user".to_string(), "44444444444444444444444444444444".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1606,10 +1563,9 @@ async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() {
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.censorship.alpn_enforce = true; cfg.censorship.alpn_enforce = true;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"66666666666666666666666666666666".to_string(), .insert("user".to_string(), "66666666666666666666666666666666".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1698,10 +1654,9 @@ async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"77777777777777777777777777777777".to_string(), .insert("user".to_string(), "77777777777777777777777777777777".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -1796,10 +1751,9 @@ async fn burst_invalid_tls_probes_are_masked_verbatim() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"88888888888888888888888888888888".to_string(), .insert("user".to_string(), "88888888888888888888888888888888".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -2027,7 +1981,10 @@ async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() {
async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() {
let user = "check-helper-user"; let user = "check-helper-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -2041,10 +1998,7 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!( assert!(first.is_ok(), "first check-only limit validation must succeed");
first.is_ok(),
"first check-only limit validation must succeed"
);
let second = RunningClientHandler::check_user_limits_static( let second = RunningClientHandler::check_user_limits_static(
user, user,
@ -2054,10 +2008,7 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!( assert!(second.is_ok(), "second check-only validation must not fail from leaked state");
second.is_ok(),
"second check-only validation must not fail from leaked state"
);
assert_eq!(stats.get_user_curr_connects(user), 0); assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
} }
@ -2066,7 +2017,10 @@ async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservatio
async fn stress_check_user_limits_static_success_never_leaks_state() { async fn stress_check_user_limits_static_success_never_leaks_state() {
let user = "check-helper-stress-user"; let user = "check-helper-stress-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -2085,10 +2039,7 @@ async fn stress_check_user_limits_static_success_never_leaks_state() {
&ip_tracker, &ip_tracker,
) )
.await; .await;
assert!( assert!(result.is_ok(), "check-only helper must remain leak-free under stress");
result.is_ok(),
"check-only helper must remain leak-free under stress"
);
} }
assert_eq!( assert_eq!(
@ -2139,7 +2090,11 @@ async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak()
41000 + i as u16, 41000 + i as u16,
); );
let result = RunningClientHandler::acquire_user_connection_reservation_static( let result = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await; .await;
assert!(matches!( assert!(matches!(
@ -2175,7 +2130,10 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.240:50002".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 4); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2213,7 +2171,10 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 4); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2243,7 +2204,10 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 4); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2284,7 +2248,10 @@ async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() {
let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 4); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 4);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2506,14 +2473,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
let user_b = "abort-isolation-b"; let user_b = "abort-isolation-b";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_max_tcp_conns.insert(user_a.to_string(), 64);
.access config.access.user_max_tcp_conns.insert(user_b.to_string(), 64);
.user_max_tcp_conns
.insert(user_a.to_string(), 64);
config
.access
.user_max_tcp_conns
.insert(user_b.to_string(), 64);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2634,7 +2595,10 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
config config
.dc_overrides .dc_overrides
.insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]); .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]);
@ -2697,10 +2661,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
) )
.await; .await;
assert!( assert!(result.is_err(), "relay must fail when upstream DC is unreachable");
result.is_err(),
"relay must fail when upstream DC is unreachable"
);
assert_eq!( assert_eq!(
stats.get_user_curr_connects(user), stats.get_user_curr_connects(user),
0, 0,
@ -2719,7 +2680,10 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2779,7 +2743,10 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2835,10 +2802,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
#[tokio::test] #[tokio::test]
async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.user_data_quota.insert("user".to_string(), 1024);
.access
.user_data_quota
.insert("user".to_string(), 1024);
let stats = Stats::new(); let stats = Stats::new();
stats.add_user_octets_from("user", 1024); stats.add_user_octets_from("user", 1024);
@ -2874,10 +2838,10 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
#[tokio::test] #[tokio::test]
async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() { async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_expirations.insert( config
"user".to_string(), .access
chrono::Utc::now() - chrono::Duration::seconds(1), .user_expirations
); .insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1));
let stats = Stats::new(); let stats = Stats::new();
let ip_tracker = UserIpTracker::new(); let ip_tracker = UserIpTracker::new();
@ -2906,7 +2870,10 @@ async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() {
let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2947,7 +2914,10 @@ async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() {
let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap(); let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -2988,7 +2958,10 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 8); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 8);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -3032,7 +3005,10 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap(); let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap();
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 1);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new()); let ip_tracker = Arc::new(UserIpTracker::new());
@ -3067,7 +3043,11 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
.expect("cross-thread cleanup must settle before reacquire check"); .expect("cross-thread cleanup must settle before reacquire check");
let reacquire = RunningClientHandler::acquire_user_connection_reservation_static( let reacquire = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer_addr, ip_tracker, user,
&config,
stats,
peer_addr,
ip_tracker,
) )
.await; .await;
assert!( assert!(
@ -3133,7 +3113,10 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
.get_recent_ips_for_users(&["user".to_string()]) .get_recent_ips_for_users(&["user".to_string()])
.await; .await;
assert!( assert!(
recent.get("user").map(|ips| ips.is_empty()).unwrap_or(true), recent
.get("user")
.map(|ips| ips.is_empty())
.unwrap_or(true),
"Concurrent rejected attempts must not leave recent IP footprint" "Concurrent rejected attempts must not leave recent IP footprint"
); );
@ -3167,7 +3150,11 @@ async fn atomic_limit_gate_allows_only_one_concurrent_acquire() {
30000 + i, 30000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
"user", &config, stats, peer, ip_tracker, "user",
&config,
stats,
peer,
ip_tracker,
) )
.await .await
.ok() .ok()
@ -3782,10 +3769,9 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"55555555555555555555555555555555".to_string(), .insert("user".to_string(), "55555555555555555555555555555555".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -3838,10 +3824,7 @@ async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut record_header = [0u8; 5]; let mut record_header = [0u8; 5];
client_side.read_exact(&mut record_header).await.unwrap(); client_side.read_exact(&mut record_header).await.unwrap();
assert_eq!( assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted");
record_header[0], 0x16,
"Valid max-length ClientHello must be accepted"
);
drop(client_side); drop(client_side);
let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) let handler_result = tokio::time::timeout(Duration::from_secs(3), handler)
@ -3882,10 +3865,9 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
cfg.censorship.mask_port = backend_addr.port(); cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0; cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.access.users.insert( cfg.access
"user".to_string(), .users
"66666666666666666666666666666666".to_string(), .insert("user".to_string(), "66666666666666666666666666666666".to_string());
);
let config = Arc::new(cfg); let config = Arc::new(cfg);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -3956,10 +3938,7 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
let mut record_header = [0u8; 5]; let mut record_header = [0u8; 5];
client.read_exact(&mut record_header).await.unwrap(); client.read_exact(&mut record_header).await.unwrap();
assert_eq!( assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted");
record_header[0], 0x16,
"Valid max-length ClientHello must be accepted"
);
drop(client); drop(client);
@ -3968,8 +3947,7 @@ async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let no_mask_connect = let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await;
tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await;
assert!( assert!(
no_mask_connect.is_err(), no_mask_connect.is_err(),
"Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path"
@ -4026,7 +4004,11 @@ async fn burst_acquire_distinct_ips(
55000 + i, 55000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
}); });
@ -4208,7 +4190,11 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() {
54000 + i, 54000 + i,
); );
RunningClientHandler::acquire_user_connection_reservation_static( RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker, user,
&config,
stats,
peer,
ip_tracker,
) )
.await .await
}); });
@ -4242,7 +4228,10 @@ async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() {
async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() {
let user: &'static str = "scheduled-attack-user"; let user: &'static str = "scheduled-attack-user";
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 6); config
.access
.user_max_tcp_conns
.insert(user.to_string(), 6);
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
@ -4251,10 +4240,7 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants()
let mut base = Vec::new(); let mut base = Vec::new();
for i in 0..5u16 { for i in 0..5u16 {
let peer = SocketAddr::new( let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i);
IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)),
56000 + i,
);
let reservation = RunningClientHandler::acquire_user_connection_reservation_static( let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user, user,
&config, &config,
@ -4302,8 +4288,15 @@ async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants()
.await .await
.expect("window cleanup must settle to expected occupancy"); .expect("window cleanup must settle to expected occupancy");
let (wave2_success, wave2_fail) = let (wave2_success, wave2_fail) = burst_acquire_distinct_ips(
burst_acquire_distinct_ips(user, config, stats.clone(), ip_tracker.clone(), 132, 32).await; user,
config,
stats.clone(),
ip_tracker.clone(),
132,
32,
)
.await;
assert_eq!(wave2_success.len(), 1); assert_eq!(wave2_success.len(), 1);
assert_eq!(wave2_fail, 31); assert_eq!(wave2_fail, 31);
assert_eq!(stats.get_user_curr_connects(user), 5); assert_eq!(stats.get_user_curr_connects(user), 5);

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
@ -135,13 +135,10 @@ async fn run_generic_once(class: ProbeClass) -> u128 {
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; REPLY_404.len()]; let mut observed = vec![0u8; REPLY_404.len()];
tokio::time::timeout( tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
Duration::from_secs(2), .await
client_side.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, REPLY_404); assert_eq!(observed, REPLY_404);
tokio::time::timeout(Duration::from_secs(2), accept_task) tokio::time::timeout(Duration::from_secs(2), accept_task)

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
fn test_probe_for_len(len: usize) -> [u8; 5] { fn test_probe_for_len(len: usize) -> [u8; 5] {
@ -100,10 +100,7 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) {
client_side.write_all(&probe).await.unwrap(); client_side.write_all(&probe).await.unwrap();
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_side.read_exact(&mut observed).await.unwrap(); client_side.read_exact(&mut observed).await.unwrap();
assert_eq!( assert_eq!(observed, backend_reply, "invalid TLS path must be masked as a real site");
observed, backend_reply,
"invalid TLS path must be masked as a real site"
);
drop(client_side); drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler) let _ = tokio::time::timeout(Duration::from_secs(3), handler)
@ -112,11 +109,7 @@ async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) {
.unwrap(); .unwrap();
accept_task.await.unwrap(); accept_task.await.unwrap();
let expected_bad = if expect_bad_increment { let expected_bad = if expect_bad_increment { bad_before + 1 } else { bad_before };
bad_before + 1
} else {
bad_before
};
assert_eq!( assert_eq!(
stats.get_connects_bad(), stats.get_connects_bad(),
expected_bad, expected_bad,
@ -194,9 +187,7 @@ fn tls_client_hello_len_bounds_stress_many_evaluations() {
for _ in 0..100_000 { for _ in 0..100_000 {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds( assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1));
MIN_TLS_CLIENT_HELLO_SIZE - 1
));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
} }
} }

View File

@ -7,7 +7,7 @@ use crate::config::{UpstreamConfig, UpstreamType};
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::time::sleep; use tokio::time::sleep;
@ -48,12 +48,7 @@ fn truncated_in_range_record(actual_body_len: usize) -> Vec<u8> {
out out
} }
async fn write_fragmented<W: AsyncWriteExt + Unpin>( async fn write_fragmented<W: AsyncWriteExt + Unpin>(writer: &mut W, bytes: &[u8], chunks: &[usize], delay_ms: u64) {
writer: &mut W,
bytes: &[u8],
chunks: &[usize],
delay_ms: u64,
) {
let mut offset = 0usize; let mut offset = 0usize;
for &chunk in chunks { for &chunk in chunks {
if offset >= bytes.len() { if offset >= bytes.len() {
@ -135,13 +130,10 @@ async fn run_blackhat_generic_fragmented_probe_should_mask(
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout( tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
Duration::from_secs(2), .await
client_side.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task) tokio::time::timeout(Duration::from_secs(2), mask_accept_task)
@ -319,13 +311,10 @@ async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() {
// Security expectation: even malformed in-range TLS should be masked. // Security expectation: even malformed in-range TLS should be masked.
// This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior.
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
tokio::time::timeout( tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed))
Duration::from_secs(2), .await
client_side.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
tokio::time::timeout(Duration::from_secs(2), mask_accept_task) tokio::time::timeout(Duration::from_secs(2), mask_accept_task)

View File

@ -2,11 +2,16 @@ use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::constants::{ use crate::protocol::constants::{
HANDSHAKE_LEN, MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION, HANDSHAKE_LEN,
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE,
TLS_RECORD_ALERT,
TLS_RECORD_APPLICATION,
TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE,
TLS_VERSION,
}; };
use crate::protocol::tls; use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
struct PipelineHarness { struct PipelineHarness {
@ -69,10 +74,7 @@ fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
} }
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> { fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
assert!( assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
tls_len <= u16::MAX as usize,
"TLS length must fit into record header"
);
let total_len = 5 + tls_len; let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len]; let mut handshake = vec![fill; total_len];
@ -179,17 +181,11 @@ async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -250,16 +246,10 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -274,11 +264,7 @@ async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() {
.unwrap(); .unwrap();
let bad_after = stats_for_assert.get_connects_bad(); let bad_after = stats_for_assert.get_connects_bad();
assert_eq!( assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS");
bad_after,
bad_before + 1,
"connects_bad must increase exactly once for invalid MTProto after valid TLS"
);
} }
#[tokio::test] #[tokio::test]
@ -324,16 +310,10 @@ async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -392,16 +372,10 @@ async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -425,8 +399,7 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() {
let backend_addr = listener.local_addr().unwrap(); let backend_addr = listener.local_addr().unwrap();
let secret = [0x85u8; 16]; let secret = [0x85u8; 16];
let client_hello = let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8);
make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8);
let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto = vec![0u8; HANDSHAKE_LEN];
let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto);
@ -470,16 +443,10 @@ async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -531,10 +498,7 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
); );
} }
assert!( assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once");
remaining.is_empty(),
"all expected client sessions must be matched exactly once"
);
}); });
let mut client_tasks = Vec::with_capacity(sessions); let mut client_tasks = Vec::with_capacity(sessions);
@ -542,8 +506,7 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
for idx in 0..sessions { for idx in 0..sessions {
let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); let harness = build_harness("86868686868686868686868686868686", backend_addr.port());
let secret = [0x86u8; 16]; let secret = [0x86u8; 16];
let client_hello = let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8);
make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8);
let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto = vec![0u8; HANDSHAKE_LEN];
let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto);
let trailing_payload = vec![idx as u8; 64 + idx]; let trailing_payload = vec![idx as u8; 64 + idx];
@ -575,16 +538,10 @@ async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
drop(client_side); drop(client_side);
@ -649,16 +606,10 @@ async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for chunk in trailing_record.chunks(3) { for chunk in trailing_record.chunks(3) {
client_side.write_all(chunk).await.unwrap(); client_side.write_all(chunk).await.unwrap();
@ -718,16 +669,10 @@ async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for b in trailing_record.iter().copied() { for b in trailing_record.iter().copied() {
client_side.write_all(&[b]).await.unwrap(); client_side.write_all(&[b]).await.unwrap();
} }
@ -791,16 +736,10 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17];
let mut pos = 0usize; let mut pos = 0usize;
@ -808,10 +747,7 @@ async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() {
while pos < trailing_record.len() { while pos < trailing_record.len() {
let step = chaos[idx % chaos.len()]; let step = chaos[idx % chaos.len()];
let end = (pos + step).min(trailing_record.len()); let end = (pos + step).min(trailing_record.len());
client_side client_side.write_all(&trailing_record[pos..end]).await.unwrap();
.write_all(&trailing_record[pos..end])
.await
.unwrap();
pos = end; pos = end;
idx += 1; idx += 1;
} }
@ -873,16 +809,10 @@ async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&r1).await.unwrap(); client_side.write_all(&r1).await.unwrap();
client_side.write_all(&r2).await.unwrap(); client_side.write_all(&r2).await.unwrap();
client_side.write_all(&r3).await.unwrap(); client_side.write_all(&r3).await.unwrap();
@ -918,10 +848,7 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend()
let mut tail = [0u8; 1]; let mut tail = [0u8; 1];
let n = stream.read(&mut tail).await.unwrap(); let n = stream.read(&mut tail).await.unwrap();
assert_eq!( assert_eq!(n, 0, "backend must observe EOF after client write half-close");
n, 0,
"backend must observe EOF after client write half-close"
);
}); });
let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port());
@ -947,16 +874,10 @@ async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap(); client_side.shutdown().await.unwrap();
@ -1017,17 +938,11 @@ async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; read_and_discard_tls_record_body(&mut client_side, tls_response_head).await;
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task) tokio::time::timeout(Duration::from_secs(3), accept_task)
@ -1079,16 +994,10 @@ async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() {
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let write_res = client_side.write_all(&trailing_record).await; let write_res = client_side.write_all(&trailing_record).await;
assert!( assert!(
write_res.is_ok() || write_res.is_err(), write_res.is_ok() || write_res.is_err(),
@ -1159,16 +1068,10 @@ async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
tokio::time::timeout(Duration::from_secs(5), accept_task) tokio::time::timeout(Duration::from_secs(5), accept_task)
@ -1249,10 +1152,7 @@ async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhel
let mut head = [0u8; 5]; let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap();
} else { } else {
let mut one = [0u8; 1]; let mut one = [0u8; 1];
@ -1341,16 +1241,10 @@ async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure()
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31];
for record in [&a, &b, &c] { for record in [&a, &b, &c] {
@ -1422,16 +1316,10 @@ async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_ve
client_side.write_all(&client_hello).await.unwrap(); client_side.write_all(&client_hello).await.unwrap();
let mut tls_response_head = [0u8; 5]; let mut tls_response_head = [0u8; 5];
client_side client_side.read_exact(&mut tls_response_head).await.unwrap();
.read_exact(&mut tls_response_head)
.await
.unwrap();
assert_eq!(tls_response_head[0], 0x16); assert_eq!(tls_response_head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&ccs).await.unwrap(); client_side.write_all(&ccs).await.unwrap();
client_side.write_all(&app).await.unwrap(); client_side.write_all(&app).await.unwrap();
client_side.write_all(&alert).await.unwrap(); client_side.write_all(&alert).await.unwrap();
@ -1484,10 +1372,7 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak()
); );
} }
assert!( assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once");
remaining.is_empty(),
"all expected sessions must be consumed exactly once"
);
}); });
let mut tasks = Vec::with_capacity(sessions); let mut tasks = Vec::with_capacity(sessions);
@ -1528,10 +1413,7 @@ async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak()
client_side.read_exact(&mut head).await.unwrap(); client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16); assert_eq!(head[0], 0x16);
client_side client_side.write_all(&invalid_mtproto_record).await.unwrap();
.write_all(&invalid_mtproto_record)
.await
.unwrap();
for chunk in record.chunks((idx % 9) + 1) { for chunk in record.chunks((idx % 9) + 1) {
client_side.write_all(chunk).await.unwrap(); client_side.write_all(chunk).await.unwrap();
} }
@ -2638,10 +2520,7 @@ async fn blackhat_coalesced_tail_parallel_32_sessions_no_cross_bleed() {
"session mixup detected in parallel-32 blackhat test" "session mixup detected in parallel-32 blackhat test"
); );
} }
assert!( assert!(remaining.is_empty(), "all expected sessions must be consumed");
remaining.is_empty(),
"all expected sessions must be consumed"
);
}); });
let mut tasks = Vec::with_capacity(sessions); let mut tasks = Vec::with_capacity(sessions);

View File

@ -1,37 +0,0 @@
use super::*;
#[test]
fn wrap_tls_application_record_empty_payload_emits_zero_length_record() {
let record = wrap_tls_application_record(&[]);
assert_eq!(record.len(), 5);
assert_eq!(record[0], TLS_RECORD_APPLICATION);
assert_eq!(&record[1..3], &TLS_VERSION);
assert_eq!(&record[3..5], &0u16.to_be_bytes());
}
#[test]
fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() {
let total = (u16::MAX as usize) + 37;
let payload = vec![0xA5u8; total];
let record = wrap_tls_application_record(&payload);
let mut offset = 0usize;
let mut recovered = Vec::with_capacity(total);
let mut frames = 0usize;
while offset + 5 <= record.len() {
assert_eq!(record[offset], TLS_RECORD_APPLICATION);
assert_eq!(&record[offset + 1..offset + 3], &TLS_VERSION);
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
let body_start = offset + 5;
let body_end = body_start + len;
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
recovered.extend_from_slice(&record[body_start..body_end]);
offset = body_end;
frames += 1;
}
assert_eq!(offset, record.len(), "record parser must consume exact output size");
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
assert_eq!(recovered, payload, "chunked records must preserve full payload");
}

View File

@ -5,10 +5,7 @@ use std::net::SocketAddr;
#[test] #[test]
fn business_scope_hint_accepts_exact_boundary_length() { fn business_scope_hint_accepts_exact_boundary_length() {
let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN)); let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN));
assert_eq!( assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"));
validated_scope_hint(&value),
Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
);
} }
#[test] #[test]
@ -27,8 +24,7 @@ fn business_known_dc_uses_ipv4_table_by_default() {
#[test] #[test]
fn business_negative_dc_maps_by_absolute_value() { fn business_negative_dc_maps_by_absolute_value() {
let cfg = ProxyConfig::default(); let cfg = ProxyConfig::default();
let resolved = let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value");
get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value");
let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }
@ -49,8 +45,7 @@ fn business_unknown_dc_uses_configured_default_dc_when_in_range() {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.default_dc = Some(4); cfg.default_dc = Some(4);
let resolved = let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default");
get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default");
let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }

View File

@ -12,8 +12,7 @@ fn common_invalid_override_entries_fallback_to_static_table() {
vec!["bad-address".to_string(), "still-bad".to_string()], vec!["bad-address".to_string(), "still-bad".to_string()],
); );
let resolved = let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
assert_eq!(resolved, expected); assert_eq!(resolved, expected);
} }
@ -26,8 +25,7 @@ fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it
cfg.dc_overrides cfg.dc_overrides
.insert("3".to_string(), vec!["203.0.113.203:443".to_string()]); .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
let resolved = let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap()); assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap());
} }

View File

@ -15,7 +15,7 @@ use std::time::Duration;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{timeout, Duration as TokioDuration};
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R> fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where where
@ -79,9 +79,7 @@ fn unknown_dc_log_respects_distinct_limit() {
#[test] #[test]
fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() { fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() {
let poisoned = Arc::new(std::sync::Mutex::new( let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<i16>::new()));
std::collections::HashSet::<i16>::new(),
));
let poisoned_for_thread = poisoned.clone(); let poisoned_for_thread = poisoned.clone();
let _ = std::thread::spawn(move || { let _ = std::thread::spawn(move || {
@ -245,10 +243,7 @@ fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() {
fs::create_dir_all(&base).expect("temp test directory must be creatable"); fs::create_dir_all(&base).expect("temp test directory must be creatable");
let candidate = base.join("unknown-dc.txt"); let candidate = base.join("unknown-dc.txt");
let candidate_relative = format!( let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id());
"target/telemt-unknown-dc-log-{}/unknown-dc.txt",
std::process::id()
);
let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) let sanitized = sanitize_unknown_dc_log_path(&candidate_relative)
.expect("safe relative path with existing parent must be accepted"); .expect("safe relative path with existing parent must be accepted");
@ -330,10 +325,7 @@ fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id()));
"telemt-unknown-dc-log-symlink-internal-{}",
std::process::id()
));
let real_parent = base.join("real_parent"); let real_parent = base.join("real_parent");
fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); fs::create_dir_all(&real_parent).expect("real parent dir must be creatable");
@ -362,10 +354,7 @@ fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id()));
"telemt-unknown-dc-log-symlink-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("symlink test directory must be creatable"); fs::create_dir_all(&base).expect("symlink test directory must be creatable");
let symlink_parent = base.join("escape_link"); let symlink_parent = base.join("escape_link");
@ -393,10 +382,7 @@ fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-target-link-{}", std::process::id()));
"telemt-unknown-dc-target-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("target-link base must be creatable"); fs::create_dir_all(&base).expect("target-link base must be creatable");
let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id()));
@ -459,10 +445,7 @@ fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-broken-link-{}", std::process::id()));
"telemt-unknown-dc-broken-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("broken-link base must be creatable"); fs::create_dir_all(&base).expect("broken-link base must be creatable");
let linked_target = base.join("unknown-dc.log"); let linked_target = base.join("unknown-dc.log");
@ -487,10 +470,7 @@ fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id()));
"telemt-unknown-dc-symlink-flip-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); fs::create_dir_all(&base).expect("symlink-flip base must be creatable");
let outside = std::env::temp_dir().join(format!( let outside = std::env::temp_dir().join(format!(
@ -550,10 +530,7 @@ fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-open-stress-{}", std::process::id()));
"telemt-unknown-dc-open-stress-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("stress open base must be creatable"); fs::create_dir_all(&base).expect("stress open base must be creatable");
let target = base.join("unknown-dc.log"); let target = base.join("unknown-dc.log");
@ -579,10 +556,7 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-safe-target-{}", std::process::id()));
"telemt-unknown-dc-safe-target-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("safe target base must be creatable"); fs::create_dir_all(&base).expect("safe target base must be creatable");
let target = base.join("unknown-dc.log"); let target = base.join("unknown-dc.log");
@ -592,8 +566,8 @@ fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() {
"target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log",
std::process::id() std::process::id()
); );
let sanitized = let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
sanitize_unknown_dc_log_path(&rel_candidate).expect("safe candidate must sanitize"); .expect("safe candidate must sanitize");
assert!( assert!(
unknown_dc_log_path_is_still_safe(&sanitized), unknown_dc_log_path_is_still_safe(&sanitized),
"revalidation must allow safe existing regular files" "revalidation must allow safe existing regular files"
@ -605,10 +579,7 @@ fn unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id()));
"telemt-unknown-dc-vanish-parent-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); fs::create_dir_all(&base).expect("vanish-parent base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -633,10 +604,7 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
let parent = std::env::current_dir() let parent = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id()));
"telemt-unknown-dc-parent-swap-{}",
std::process::id()
));
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -665,10 +633,7 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
let parent = std::env::current_dir() let parent = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id()));
"telemt-unknown-dc-check-open-race-{}",
std::process::id()
));
fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); fs::create_dir_all(&parent).expect("check-open-race parent must be creatable");
let target = parent.join("unknown-dc.log"); let target = parent.join("unknown-dc.log");
@ -677,7 +642,8 @@ fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() {
"target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log",
std::process::id() std::process::id()
); );
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
.expect("candidate must sanitize");
assert!( assert!(
unknown_dc_log_path_is_still_safe(&sanitized), unknown_dc_log_path_is_still_safe(&sanitized),
@ -709,10 +675,7 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id()));
"telemt-unknown-dc-parent-swap-openat-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@ -745,10 +708,7 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
.expect_err("anchored open must fail when parent is swapped to symlink"); .expect_err("anchored open must fail when parent is swapped to symlink");
let raw = err.raw_os_error(); let raw = err.raw_os_error();
assert!( assert!(
matches!( matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)),
raw,
Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)
),
"anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}"
); );
assert!( assert!(
@ -757,284 +717,6 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
); );
} }
#[cfg(unix)]
#[test]
fn anchored_open_nix_path_writes_expected_lines() {
    // Per-process scratch directory under target/ keeps parallel test runs
    // from colliding on the same log file.
    let pid = std::process::id();
    let scratch_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-anchored-open-ok-{}", pid));
    fs::create_dir_all(&scratch_dir).expect("anchored-open-ok base must be creatable");

    let candidate = format!(
        "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log",
        pid
    );
    let safe_path = sanitize_unknown_dc_log_path(&candidate).expect("candidate must sanitize");
    // Start clean: the file may linger from a previous run of this test.
    let _ = fs::remove_file(&safe_path.resolved_path);

    // First open must create the file; second must reopen it for append.
    let mut created = open_unknown_dc_log_append_anchored(&safe_path)
        .expect("anchored open must create log file in allowed parent");
    append_unknown_dc_line(&mut created, 31_200).expect("first append must succeed");
    let mut reopened = open_unknown_dc_log_append_anchored(&safe_path)
        .expect("anchored reopen must succeed for existing regular file");
    append_unknown_dc_line(&mut reopened, 31_201).expect("second append must succeed");

    let body = fs::read_to_string(&safe_path.resolved_path)
        .expect("anchored log file must be readable");
    let written: Vec<&str> = body
        .lines()
        .filter(|entry| !entry.trim().is_empty())
        .collect();
    assert_eq!(written.len(), 2, "expected one line per anchored append call");
    assert!(
        written.contains(&"dc_idx=31200") && written.contains(&"dc_idx=31201"),
        "anchored append output must contain both expected dc_idx lines"
    );
}
#[cfg(unix)]
#[test]
fn anchored_open_parallel_appends_preserve_line_integrity() {
    // Per-process scratch directory under target/ isolates this run.
    let pid = std::process::id();
    let scratch_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-anchored-open-parallel-{}", pid));
    fs::create_dir_all(&scratch_dir).expect("anchored-open-parallel base must be creatable");

    let candidate = format!(
        "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log",
        pid
    );
    let safe_path = sanitize_unknown_dc_log_path(&candidate).expect("candidate must sanitize");
    let _ = fs::remove_file(&safe_path.resolved_path);

    // 64 threads, each opening its own descriptor and appending one line
    // with a distinct dc_idx.
    let handles: Vec<_> = (0..64i16)
        .map(|idx| {
            let per_thread = safe_path.clone();
            std::thread::spawn(move || {
                let mut file = open_unknown_dc_log_append_anchored(&per_thread)
                    .expect("anchored open must succeed in worker");
                append_unknown_dc_line(&mut file, 32_000 + idx)
                    .expect("worker append must succeed");
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("worker must not panic");
    }

    let body = fs::read_to_string(&safe_path.resolved_path)
        .expect("parallel log file must be readable");
    let written: Vec<&str> = body
        .lines()
        .filter(|entry| !entry.trim().is_empty())
        .collect();
    assert_eq!(written.len(), 64, "expected one complete line per worker append");
    for line in written {
        // Appends from concurrent writers must land as whole lines.
        assert!(
            line.starts_with("dc_idx="),
            "line must keep dc_idx prefix and not be interleaved: {line}"
        );
        let value = line
            .strip_prefix("dc_idx=")
            .expect("prefix checked above")
            .parse::<i16>();
        assert!(
            value.is_ok(),
            "line payload must remain parseable i16 and not be corrupted: {line}"
        );
    }
}
#[cfg(unix)]
#[test]
fn anchored_open_creates_private_0600_file_permissions() {
    use std::os::unix::fs::PermissionsExt;

    // Per-process scratch directory under target/ isolates this run.
    let pid = std::process::id();
    let scratch_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-anchored-perms-{}", pid));
    fs::create_dir_all(&scratch_dir).expect("anchored-perms base must be creatable");

    let candidate = format!(
        "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log",
        pid
    );
    let safe_path = sanitize_unknown_dc_log_path(&candidate).expect("candidate must sanitize");
    let _ = fs::remove_file(&safe_path.resolved_path);

    // Create the file through the anchored-open path and write one entry,
    // then close it before inspecting permissions.
    let mut log = open_unknown_dc_log_append_anchored(&safe_path)
        .expect("anchored open must create file with restricted mode");
    append_unknown_dc_line(&mut log, 31_210).expect("initial append must succeed");
    drop(log);

    // Only the permission bits are relevant; mask off the file-type bits.
    let mode = fs::metadata(&safe_path.resolved_path)
        .expect("created log file metadata must be readable")
        .permissions()
        .mode()
        & 0o777;
    assert_eq!(
        mode, 0o600,
        "anchored open must create unknown-dc log file with owner-only rw permissions"
    );
}
#[cfg(unix)]
#[test]
// The anchored open must refuse to follow a symlink planted at the log
// file's final path component, so log writes cannot be redirected to an
// arbitrary file outside the sanctioned parent directory.
fn anchored_open_rejects_existing_symlink_target() {
    use std::os::unix::fs::symlink;
    // Per-process scratch directory under target/ keeps runs isolated.
    let base = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!(
            "telemt-unknown-dc-anchored-symlink-target-{}",
            std::process::id()
        ));
    fs::create_dir_all(&base).expect("anchored-symlink-target base must be creatable");
    let rel_candidate = format!(
        "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log",
        std::process::id()
    );
    let sanitized =
        sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
    // The "outside" file in the system temp dir is the symlink's escape
    // target; the rejected open must never touch it.
    let outside = std::env::temp_dir().join(format!(
        "telemt-unknown-dc-anchored-symlink-outside-{}.log",
        std::process::id()
    ));
    fs::write(&outside, "outside\n").expect("outside baseline file must be writable");
    // Replace any existing log file with a symlink that points outside
    // the allowed parent, then attempt the anchored open.
    let _ = fs::remove_file(&sanitized.resolved_path);
    symlink(&outside, &sanitized.resolved_path)
        .expect("target symlink for anchored-open rejection test must be creatable");

    // ELOOP is the errno O_NOFOLLOW reports for a symlink at the final
    // path component — presumably the anchored open uses O_NOFOLLOW; the
    // implementation is outside this view (NOTE(review): confirm).
    let err = open_unknown_dc_log_append_anchored(&sanitized)
        .expect_err("anchored open must reject symlinked filename target");
    assert_eq!(
        err.raw_os_error(),
        Some(libc::ELOOP),
        "anchored open should fail closed with ELOOP on symlinked target"
    );
}
#[cfg(unix)]
#[test]
fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
    const WORKERS: usize = 24;
    const ROUNDS: usize = 40;

    // Per-process scratch directory under target/ isolates this run.
    let pid = std::process::id();
    let scratch_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-anchored-contention-{}", pid));
    fs::create_dir_all(&scratch_dir).expect("anchored-contention base must be creatable");

    let candidate = format!(
        "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log",
        pid
    );
    let safe_path = sanitize_unknown_dc_log_path(&candidate).expect("candidate must sanitize");
    let _ = fs::remove_file(&safe_path.resolved_path);

    // Each worker re-opens the log on every round, so both the open path
    // and the append path run under heavy contention.
    let handles: Vec<_> = (0..WORKERS)
        .map(|worker| {
            let per_thread = safe_path.clone();
            std::thread::spawn(move || {
                for round in 0..ROUNDS {
                    let mut file = open_unknown_dc_log_append_anchored(&per_thread)
                        .expect("anchored open must succeed under contention");
                    // 20_000..20_959 — every logical write gets a distinct,
                    // i16-representable dc_idx.
                    let dc_idx = 20_000i16.wrapping_add((worker * ROUNDS + round) as i16);
                    append_unknown_dc_line(&mut file, dc_idx)
                        .expect("each contention append must complete");
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("contention worker must not panic");
    }

    let body = fs::read_to_string(&safe_path.resolved_path)
        .expect("contention output file must be readable");
    let written: Vec<&str> = body
        .lines()
        .filter(|entry| !entry.trim().is_empty())
        .collect();
    assert_eq!(
        written.len(),
        WORKERS * ROUNDS,
        "every contention append must produce exactly one line"
    );
    // dc_idx values are all distinct, so the set size proves no logical
    // write was lost or duplicated.
    let mut seen = std::collections::HashSet::new();
    for line in written {
        assert!(
            line.starts_with("dc_idx="),
            "line must preserve expected prefix under heavy contention: {line}"
        );
        let value = line
            .strip_prefix("dc_idx=")
            .expect("prefix validated")
            .parse::<i16>()
            .expect("line payload must remain parseable i16 under contention");
        seen.insert(value);
    }
    assert_eq!(
        seen.len(),
        WORKERS * ROUNDS,
        "contention output must not lose or duplicate logical writes"
    );
}
#[cfg(unix)]
#[test]
fn append_unknown_dc_line_returns_error_for_read_only_descriptor() {
    // Per-process scratch directory under target/ isolates this run.
    let pid = std::process::id();
    let scratch_dir = std::env::current_dir()
        .expect("cwd must be available")
        .join("target")
        .join(format!("telemt-unknown-dc-append-ro-{}", pid));
    fs::create_dir_all(&scratch_dir).expect("append-ro base must be creatable");

    let candidate = format!(
        "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log",
        pid
    );
    let safe_path = sanitize_unknown_dc_log_path(&candidate).expect("candidate must sanitize");
    // Seed exactly one line so any (wrongly) successful append is detectable.
    fs::write(&safe_path.resolved_path, "seed\n").expect("seed file must be writable");

    // A descriptor opened without write access must make the append return
    // an error rather than silently dropping the entry.
    let mut readonly = std::fs::OpenOptions::new()
        .read(true)
        .open(&safe_path.resolved_path)
        .expect("readonly file open must succeed");
    append_unknown_dc_line(&mut readonly, 31_222)
        .expect_err("append on readonly descriptor must fail closed");

    let content_after = fs::read_to_string(&safe_path.resolved_path)
        .expect("seed file must remain readable");
    assert_eq!(
        nonempty_line_count(&content_after),
        1,
        "failed readonly append must not modify persisted unknown-dc log content"
    );
}
#[tokio::test] #[tokio::test]
async fn unknown_dc_absolute_log_path_writes_one_entry() { async fn unknown_dc_absolute_log_path_writes_one_entry() {
let _guard = unknown_dc_test_lock() let _guard = unknown_dc_test_lock()
@ -1214,10 +896,7 @@ async fn unknown_dc_symlinked_target_escape_is_not_written_integration() {
let base = std::env::current_dir() let base = std::env::current_dir()
.expect("cwd must be available") .expect("cwd must be available")
.join("target") .join("target")
.join(format!( .join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id()));
"telemt-unknown-dc-no-write-link-{}",
std::process::id()
));
fs::create_dir_all(&base).expect("integration symlink base must be creatable"); fs::create_dir_all(&base).expect("integration symlink base must be creatable");
let outside = std::env::temp_dir().join(format!( let outside = std::env::temp_dir().join(format!(
@ -1345,17 +1024,11 @@ async fn direct_relay_abort_midflight_releases_route_gauge() {
} }
}) })
.await; .await;
assert!( assert!(started.is_ok(), "direct relay must increment route gauge before abort");
started.is_ok(),
"direct relay must increment route gauge before abort"
);
relay_task.abort(); relay_task.abort();
let joined = relay_task.await; let joined = relay_task.await;
assert!( assert!(joined.is_err(), "aborted direct relay task must return join error");
joined.is_err(),
"aborted direct relay task must return join error"
);
tokio::time::sleep(Duration::from_millis(20)).await; tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!( assert_eq!(
@ -1640,22 +1313,15 @@ fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() {
], ],
); );
let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve");
assert!( assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present");
a.is_ipv6(),
"prefer_v6 should choose v6 override when present"
);
let mut cfg_b = ProxyConfig::default(); let mut cfg_b = ProxyConfig::default();
cfg_b.network.prefer = 6; cfg_b.network.prefer = 6;
cfg_b.network.ipv6 = Some(true); cfg_b.network.ipv6 = Some(true);
cfg_b cfg_b.dc_overrides
.dc_overrides
.insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]);
let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve");
assert!( assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used");
b.is_ipv4(),
"when no v6 override exists, v4 override must be used"
);
let mut cfg_c = ProxyConfig::default(); let mut cfg_c = ProxyConfig::default();
cfg_c.network.prefer = 6; cfg_c.network.prefer = 6;
@ -1684,8 +1350,7 @@ fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallb
], ],
); );
let addr = get_dc_addr_static(dc_idx, &cfg) let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive");
.expect("at least one valid override must keep resolution alive");
assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap()); assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap());
} }
@ -1705,10 +1370,7 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve"); let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve");
let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve"); let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve");
assert_eq!( assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}");
first, second,
"override resolution must stay deterministic for dc {idx}"
);
assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred");
} }
} }
@ -1717,12 +1379,12 @@ fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() {
async fn negative_direct_relay_dc_connection_refused_fails_fast() { async fn negative_direct_relay_dc_connection_refused_fails_fast() {
let (client_reader_side, _client_writer_side) = duplex(1024); let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024); let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32]; let key = [0u8; 32];
let iv = 0u128; let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
@ -1735,11 +1397,9 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
drop(listener); drop(listener);
let mut config_with_override = ProxyConfig::default(); let mut config_with_override = ProxyConfig::default();
config_with_override config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
.dc_overrides
.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override); let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new( let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig { vec![UpstreamConfig {
enabled: true, enabled: true,
@ -1758,7 +1418,7 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
false, false,
stats.clone(), stats.clone(),
)); ));
let success = HandshakeSuccess { let success = HandshakeSuccess {
user: "test-user".to_string(), user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(), peer: "127.0.0.1:12345".parse().unwrap(),
@ -1800,21 +1460,21 @@ async fn negative_direct_relay_dc_connection_refused_fails_fast() {
async fn adversarial_direct_relay_cutover_integrity() { async fn adversarial_direct_relay_cutover_integrity() {
let (client_reader_side, _client_writer_side) = duplex(1024); let (client_reader_side, _client_writer_side) = duplex(1024);
let (_client_reader_relay, client_writer_side) = duplex(1024); let (_client_reader_relay, client_writer_side) = duplex(1024);
let key = [0u8; 32]; let key = [0u8; 32];
let iv = 0u128; let iv = 0u128;
let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv));
let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); let buffer_pool = Arc::new(BufferPool::with_config(1024, 1));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct);
// Mock upstream server. // Mock upstream server.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let dc_addr = listener.local_addr().unwrap(); let dc_addr = listener.local_addr().unwrap();
tokio::spawn(async move { tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
// Read handshake nonce. // Read handshake nonce.
@ -1825,11 +1485,9 @@ async fn adversarial_direct_relay_cutover_integrity() {
}); });
let mut config_with_override = ProxyConfig::default(); let mut config_with_override = ProxyConfig::default();
config_with_override config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]);
.dc_overrides
.insert("1".to_string(), vec![dc_addr.to_string()]);
let config = Arc::new(config_with_override); let config = Arc::new(config_with_override);
let upstream_manager = Arc::new(UpstreamManager::new( let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig { vec![UpstreamConfig {
enabled: true, enabled: true,
@ -1848,7 +1506,7 @@ async fn adversarial_direct_relay_cutover_integrity() {
false, false,
stats.clone(), stats.clone(),
)); ));
let success = HandshakeSuccess { let success = HandshakeSuccess {
user: "test-user".to_string(), user: "test-user".to_string(),
peer: "127.0.0.1:12345".parse().unwrap(), peer: "127.0.0.1:12345".parse().unwrap(),
@ -1876,8 +1534,7 @@ async fn adversarial_direct_relay_cutover_integrity() {
runtime_clone.subscribe(), runtime_clone.subscribe(),
runtime_clone.snapshot(), runtime_clone.snapshot(),
0xABCD_1234, 0xABCD_1234,
) ).await
.await
}); });
timeout(TokioDuration::from_secs(2), async { timeout(TokioDuration::from_secs(2), async {
@ -1890,10 +1547,10 @@ async fn adversarial_direct_relay_cutover_integrity() {
}) })
.await .await
.expect("direct relay session must start before cutover"); .expect("direct relay session must start before cutover");
// Trigger cutover. // Trigger cutover.
route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); route_runtime.set_mode(RelayRouteMode::Middle).unwrap();
// The session should terminate after the staggered delay (1000-2000ms). // The session should terminate after the staggered delay (1000-2000ms).
let result = timeout(TokioDuration::from_secs(5), session_task) let result = timeout(TokioDuration::from_secs(5), session_task)
.await .await

View File

@ -40,7 +40,9 @@ fn subtle_light_fuzz_scope_hint_matches_oracle() {
}; };
!rest.is_empty() !rest.is_empty()
&& rest.len() <= MAX_SCOPE_HINT_LEN && rest.len() <= MAX_SCOPE_HINT_LEN
&& rest.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-') && rest
.bytes()
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
} }
let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; let mut state: u64 = 0xC0FF_EE11_D15C_AFE5;
@ -92,10 +94,7 @@ fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() {
let dc_idx = (state as i16).wrapping_sub(16_384); let dc_idx = (state as i16).wrapping_sub(16_384);
let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail");
assert_eq!( assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT);
resolved.port(),
crate::protocol::constants::TG_DATACENTER_PORT
);
let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true);
assert_eq!(resolved.is_ipv6(), expect_v6); assert_eq!(resolved.is_ipv6(), expect_v6);
} }
@ -167,9 +166,7 @@ async fn subtle_integration_parallel_unique_dcs_log_unique_lines() {
cfg.general.unknown_dc_log_path = Some(rel_file); cfg.general.unknown_dc_log_path = Some(rel_file);
let cfg = Arc::new(cfg); let cfg = Arc::new(cfg);
let dcs = [ let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908];
31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908,
];
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for dc in dcs { for dc in dcs {

View File

@ -1,14 +1,10 @@
use super::*; use super::*;
use crate::crypto::sha256;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc; use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use crate::crypto::sha256;
fn make_valid_mtproto_handshake( fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode"); let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN]; let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
@ -53,9 +49,7 @@ fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
cfg.access cfg.access.users.insert("user".to_string(), secret_hex.to_string());
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true; cfg.general.modes.secure = true;
cfg cfg
@ -77,19 +71,9 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() {
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
// Baseline check // Baseline check
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res { match res {
HandshakeResult::Success(_) => {} HandshakeResult::Success(_) => {},
_ => panic!("Baseline failed: expected Success"), _ => panic!("Baseline failed: expected Success"),
} }
@ -97,21 +81,8 @@ async fn mtproto_handshake_bit_flip_anywhere_rejected() {
for byte_pos in SKIP_LEN..HANDSHAKE_LEN { for byte_pos in SKIP_LEN..HANDSHAKE_LEN {
let mut h = base; let mut h = base;
h[byte_pos] ^= 0x01; // Flip 1 bit h[byte_pos] ^= 0x01; // Flip 1 bit
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&h, assert!(matches!(res, HandshakeResult::BadClient { .. }), "Flip at byte {byte_pos} bit 0 must be rejected");
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Flip at byte {byte_pos} bit 0 must be rejected"
);
} }
} }
@ -128,51 +99,25 @@ async fn mtproto_handshake_timing_neutrality_mocked() {
let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap();
const ITER: usize = 50; const ITER: usize = 50;
let mut start = Instant::now(); let mut start = Instant::now();
for _ in 0..ITER { for _ in 0..ITER {
let _ = handle_mtproto_handshake( let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
} }
let duration_success = start.elapsed(); let duration_success = start.elapsed();
start = Instant::now(); start = Instant::now();
for i in 0..ITER { for i in 0..ITER {
let mut h = base; let mut h = base;
h[SKIP_LEN + (i % 48)] ^= 0xFF; h[SKIP_LEN + (i % 48)] ^= 0xFF;
let _ = handle_mtproto_handshake( let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&h,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
} }
let duration_fail = start.elapsed(); let duration_fail = start.elapsed();
let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64) let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64;
.abs()
/ ITER as f64;
// Threshold (loose for CI) // Threshold (loose for CI)
assert!( assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms);
avg_diff_ms < 100.0,
"Timing difference too large: {} ms/iter",
avg_diff_ms
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -185,13 +130,13 @@ async fn auth_probe_throttle_saturation_stress() {
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
let now = Instant::now(); let now = Instant::now();
// Record enough failures for one IP to trigger backoff // Record enough failures for one IP to trigger backoff
let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
auth_probe_record_failure(target_ip, now); auth_probe_record_failure(target_ip, now);
} }
assert!(auth_probe_is_throttled(target_ip, now)); assert!(auth_probe_is_throttled(target_ip, now));
// Stress test with many unique IPs // Stress test with many unique IPs
@ -200,7 +145,10 @@ async fn auth_probe_throttle_saturation_stress() {
auth_probe_record_failure(ip, now); auth_probe_record_failure(ip, now);
} }
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); let tracked = AUTH_PROBE_STATE
.get()
.map(|state| state.len())
.unwrap_or(0);
assert!( assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}"
@ -218,17 +166,7 @@ async fn mtproto_handshake_abridged_prefix_rejected() {
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
// MTProxy stops immediately on 0xef // MTProxy stops immediately on 0xef
assert!(matches!(res, HandshakeResult::BadClient { .. })); assert!(matches!(res, HandshakeResult::BadClient { .. }));
} }
@ -240,17 +178,11 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() {
let secret1_hex = "11111111111111111111111111111111"; let secret1_hex = "11111111111111111111111111111111";
let secret2_hex = "22222222222222222222222222222222"; let secret2_hex = "22222222222222222222222222222222";
let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1);
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config config.access.users.insert("user1".to_string(), secret1_hex.to_string());
.access config.access.users.insert("user2".to_string(), secret2_hex.to_string());
.users
.insert("user1".to_string(), secret1_hex.to_string());
config
.access
.users
.insert("user2".to_string(), secret2_hex.to_string());
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
config.general.modes.secure = true; config.general.modes.secure = true;
@ -258,17 +190,7 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() {
let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap();
// Even if we prefer user1, if user2 matches, it should succeed. // Even if we prefer user1, if user2 matches, it should succeed.
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await;
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
Some("user1"),
)
.await;
if let HandshakeResult::Success((_, _, success)) = res { if let HandshakeResult::Success((_, _, success)) = res {
assert_eq!(success.user, "user2"); assert_eq!(success.user, "user2");
} else { } else {
@ -287,30 +209,20 @@ async fn mtproto_handshake_concurrent_flood_stability() {
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60)));
let config = Arc::new(config); let config = Arc::new(config);
let mut tasks = Vec::new(); let mut tasks = Vec::new();
for i in 0..50 { for i in 0..50 {
let base = base; let base = base;
let config = Arc::clone(&config); let config = Arc::clone(&config);
let replay_checker = Arc::clone(&replay_checker); let replay_checker = Arc::clone(&replay_checker);
let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap();
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await;
&base,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
matches!(res, HandshakeResult::Success(_)) matches!(res, HandshakeResult::Success(_))
})); }));
} }
// We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk), // We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk),
// but the system must not panic or hang. // but the system must not panic or hang.
for task in tasks { for task in tasks {
@ -394,10 +306,7 @@ async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() {
.expect("fuzzed mutation must complete in bounded time"); .expect("fuzzed mutation must complete in bounded time");
assert!( assert!(
matches!( matches!(res, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)),
res,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutation corpus must stay within explicit handshake outcomes" "mutation corpus must stay within explicit handshake outcomes"
); );
} }
@ -436,12 +345,7 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) { for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) {
let peer: SocketAddr = SocketAddr::new( let peer: SocketAddr = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new( IpAddr::V4(Ipv4Addr::new(10, (i / 65535) as u8, ((i / 255) % 255) as u8, (i % 255 + 1) as u8)),
10,
(i / 65535) as u8,
((i / 255) % 255) as u8,
(i % 255 + 1) as u8,
)),
43000 + (i % 20000) as u16, 43000 + (i % 20000) as u16,
); );
let res = handle_mtproto_handshake( let res = handle_mtproto_handshake(
@ -458,7 +362,10 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
assert!(matches!(res, HandshakeResult::BadClient { .. })); assert!(matches!(res, HandshakeResult::BadClient { .. }));
} }
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0); let tracked = AUTH_PROBE_STATE
.get()
.map(|state| state.len())
.unwrap_or(0);
assert!( assert!(
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"probe map must remain bounded under invalid storm: {tracked}" "probe map must remain bounded under invalid storm: {tracked}"
@ -508,10 +415,7 @@ async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() {
.expect("mutation iteration must complete in bounded time"); .expect("mutation iteration must complete in bounded time");
assert!( assert!(
matches!( matches!(outcome, HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)),
outcome,
HandshakeResult::BadClient { .. } | HandshakeResult::Success(_)
),
"mutations must remain fail-closed/auth-only" "mutations must remain fail-closed/auth-only"
); );
} }

View File

@ -6,7 +6,7 @@ use crate::protocol::constants::ProtoTag;
use crate::stats::ReplayChecker; use crate::stats::ReplayChecker;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::MutexGuard; use std::sync::MutexGuard;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{timeout, Duration as TokioDuration};
fn make_mtproto_handshake_with_proto_bytes( fn make_mtproto_handshake_with_proto_bytes(
secret_hex: &str, secret_hex: &str,
@ -48,20 +48,14 @@ fn make_mtproto_handshake_with_proto_bytes(
handshake handshake
} }
fn make_valid_mtproto_handshake( fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx)
} }
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
cfg.access cfg.access.users.insert("user".to_string(), secret_hex.to_string());
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true; cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true; cfg.general.modes.secure = true;
cfg cfg
@ -146,9 +140,7 @@ async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
for _ in 0..32 { for _ in 0..32 {
let mut mutated = base; let mut mutated = base;
for _ in 0..4 { for _ in 0..4 {
seed = seed seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN));
mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1);
} }
@ -275,4 +267,4 @@ async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_re
} }
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
} }

View File

@ -1,8 +1,8 @@
use super::*; use super::*;
use crate::crypto::{sha256, sha256_hmac}; use crate::crypto::{sha256, sha256_hmac};
use dashmap::DashMap; use dashmap::DashMap;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng}; use rand::{RngExt, SeedableRng};
use rand::rngs::StdRng;
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -80,7 +80,8 @@ fn make_valid_tls_client_hello_with_alpn(
for i in 0..4 { for i in 0..4 {
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record record
} }
@ -150,7 +151,8 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
for i in 0..4 { for i in 0..4 {
digest[28 + i] ^= ts[i]; digest[28 + i] ^= ts[i];
} }
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record record
} }
@ -165,11 +167,7 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
cfg cfg
} }
fn make_valid_mtproto_handshake( fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] {
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper");
let mut handshake = [0x5Au8; HANDSHAKE_LEN]; let mut handshake = [0x5Au8; HANDSHAKE_LEN];
@ -330,10 +328,7 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
expected.extend_from_slice(&client_enc_iv.to_be_bytes()); expected.extend_from_slice(&client_enc_iv.to_be_bytes());
expected.reverse(); expected.reverse();
assert_eq!( assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice());
&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN],
expected.as_slice()
);
} }
#[test] #[test]
@ -450,9 +445,7 @@ async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_bloc
#[tokio::test] #[tokio::test]
async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() {
let secret = [0x77u8; 16]; let secret = [0x77u8; 16];
let config = Arc::new(test_config_with_secret_hex( let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777"));
"77777777777777777777777777777777",
));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); let handshake = Arc::new(make_valid_tls_handshake(&secret, 0));
@ -792,10 +785,10 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() {
.access .access
.users .users
.insert("broken_user".to_string(), "aa".to_string()); .insert("broken_user".to_string(), "aa".to_string());
config.access.users.insert( config
"valid_user".to_string(), .access
"22222222222222222222222222222222".to_string(), .users
); .insert("valid_user".to_string(), "22222222222222222222222222222222".to_string());
config.access.ignore_time_skew = true; config.access.ignore_time_skew = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
@ -836,8 +829,12 @@ async fn tls_sni_preferred_user_hint_selects_matching_identity_first() {
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap();
let handshake = let handshake = make_valid_tls_client_hello_with_sni_and_alpn(
make_valid_tls_client_hello_with_sni_and_alpn(&shared_secret, 0, "user-b", &[b"h2"]); &shared_secret,
0,
"user-b",
&[b"h2"],
);
let result = handle_tls_handshake( let result = handle_tls_handshake(
&handshake, &handshake,
@ -871,10 +868,10 @@ fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() {
let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();
for i in 0..4096usize { for i in 0..4096usize {
config config.access.users.insert(
.access format!("decoy-{i:04}.example"),
.users secret_hex.clone(),
.insert(format!("decoy-{i:04}.example"), secret_hex.clone()); );
} }
config config
.access .access
@ -913,10 +910,10 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string();
for i in 0..4096usize { for i in 0..4096usize {
config config.access.users.insert(
.access format!("decoy-{i:04}.example"),
.users secret_hex.clone(),
.insert(format!("decoy-{i:04}.example"), secret_hex.clone()); );
} }
config config
.access .access
@ -948,7 +945,8 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
match result { match result {
HandshakeResult::Success((_, _, user)) => { HandshakeResult::Success((_, _, user)) => {
assert_eq!( assert_eq!(
user, preferred_user, user,
preferred_user,
"SNI preferred-user hint must remain stable under large user cardinality" "SNI preferred-user hint must remain stable under large user cardinality"
); );
} }
@ -1882,15 +1880,11 @@ fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() {
"different IPv6 /64 prefixes must not share throttle buckets" "different IPv6 /64 prefixes must not share throttle buckets"
); );
assert_eq!( assert_eq!(
state state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak),
.get(&normalize_auth_probe_ip(ip_a))
.map(|entry| entry.fail_streak),
Some(1) Some(1)
); );
assert_eq!( assert_eq!(
state state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak),
.get(&normalize_auth_probe_ip(ip_b))
.map(|entry| entry.fail_streak),
Some(1) Some(1)
); );
} }
@ -1950,6 +1944,7 @@ fn auth_probe_eviction_offset_changes_with_time_component() {
); );
} }
#[test] #[test]
fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() { fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() {
let _guard = auth_probe_test_lock() let _guard = auth_probe_test_lock()
@ -1991,10 +1986,7 @@ fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40)); let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40));
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1)); auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1));
assert!( assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure");
state.get(&newcomer).is_some(),
"newcomer must still be tracked under over-cap pressure"
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"high fail-streak sentinel must survive round-limited eviction" "high fail-streak sentinel must survive round-limited eviction"
@ -2085,20 +2077,13 @@ fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket()
((step >> 8) & 0xff) as u8, ((step >> 8) & 0xff) as u8,
(step & 0xff) as u8, (step & 0xff) as u8,
)); ));
auth_probe_record_failure_with_state( auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1));
&state,
newcomer,
base_now + Duration::from_millis(step as u64 + 1),
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"step {step}: high-threat sentinel must not be starved by newcomer churn" "step {step}: high-threat sentinel must not be starved by newcomer churn"
); );
assert!( assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked");
state.get(&newcomer).is_some(),
"step {step}: newcomer must be tracked"
);
} }
} }
@ -2144,22 +2129,10 @@ fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() {
); );
} }
let newcomer = IpAddr::V4(Ipv4Addr::new( let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8));
203, auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1));
10,
((round >> 8) & 0xff) as u8,
(round & 0xff) as u8,
));
auth_probe_record_failure_with_state(
&state,
newcomer,
now + Duration::from_millis(round as u64 + 1),
);
assert!( assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked");
state.get(&newcomer).is_some(),
"round {round}: newcomer should be tracked"
);
assert!( assert!(
state.get(&sentinel).is_some(), state.get(&sentinel).is_some(),
"round {round}: high fail-streak sentinel should survive mixed low-threat pool" "round {round}: high fail-streak sentinel should survive mixed low-threat pool"
@ -2172,12 +2145,7 @@ fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() {
let base = Instant::now(); let base = Instant::now();
for _ in 0..4096usize { for _ in 0..4096usize {
let ip = IpAddr::V4(Ipv4Addr::new( let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random()));
rng.random(),
rng.random(),
rng.random(),
rng.random(),
));
let offset_ns = rng.random_range(0_u64..2_000_000); let offset_ns = rng.random_range(0_u64..2_000_000);
let when = base + Duration::from_nanos(offset_ns); let when = base + Duration::from_nanos(offset_ns);
@ -2276,7 +2244,8 @@ async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() {
let streak = auth_probe_fail_streak_for_testing(peer_ip) let streak = auth_probe_fail_streak_for_testing(peer_ip)
.expect("tracked peer must exist after concurrent failure burst"); .expect("tracked peer must exist after concurrent failure burst");
assert_eq!( assert_eq!(
streak as usize, tasks, streak as usize,
tasks,
"concurrent failures for one source must account every attempt" "concurrent failures for one source must account every attempt"
); );
} }
@ -2289,9 +2258,7 @@ async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake()
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
let secret = [0x31u8; 16]; let secret = [0x31u8; 16];
let config = Arc::new(test_config_with_secret_hex( let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131"));
"31313131313131313131313131313131",
));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap();
@ -2878,10 +2845,7 @@ async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing()
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected));
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected)
);
} }
{ {
@ -2960,10 +2924,7 @@ async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementin
) )
.await; .await;
assert!(matches!(result, HandshakeResult::BadClient { .. })); assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!( assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected));
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected)
);
} }
{ {
@ -3187,9 +3148,7 @@ async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grac
.unwrap_or_else(|poisoned| poisoned.into_inner()); .unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing(); clear_auth_probe_state_for_testing();
let config = Arc::new(test_config_with_secret_hex( let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575"));
"75757575757575757575757575757575",
));
let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap();
@ -3337,11 +3296,7 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
} }
let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0));
let valid_mtproto = Arc::new(make_valid_mtproto_handshake( let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3));
secret_hex,
ProtoTag::Secure,
3,
));
let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
let invalid_tls = Arc::new(invalid_tls); let invalid_tls = Arc::new(invalid_tls);
@ -3413,9 +3368,7 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
match task.await.unwrap() { match task.await.unwrap() {
HandshakeResult::BadClient { .. } => bad_clients += 1, HandshakeResult::BadClient { .. } => bad_clients += 1,
HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"),
HandshakeResult::Error(err) => { HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"),
panic!("unexpected error in invalid TLS saturation burst test: {err}")
}
} }
} }
@ -3432,7 +3385,8 @@ async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshak
); );
assert_eq!( assert_eq!(
bad_clients, 48, bad_clients,
48,
"all invalid TLS probes in mixed saturation burst must be rejected" "all invalid TLS probes in mixed saturation burst must be rejected"
); );
} }

View File

@ -57,25 +57,6 @@ fn spread_u128(values: &[u128]) -> u128 {
max_v - min_v max_v - min_v
} }
fn interval_gap_usize(a: &BTreeSet<usize>, b: &BTreeSet<usize>) -> usize {
if a.is_empty() || b.is_empty() {
return 0;
}
let a_min = *a.iter().next().unwrap();
let a_max = *a.iter().next_back().unwrap();
let b_min = *b.iter().next().unwrap();
let b_max = *b.iter().next_back().unwrap();
if a_max < b_min {
b_min - a_max
} else if b_max < a_min {
a_min - b_max
} else {
0
}
}
async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec<u128> { async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: usize) -> Vec<u128> {
let mut out = Vec::with_capacity(n); let mut out = Vec::with_capacity(n);
for _ in 0..n { for _ in 0..n {
@ -285,40 +266,23 @@ async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baselin
let baseline_overlap = baseline_a.intersection(&baseline_b).count(); let baseline_overlap = baseline_a.intersection(&baseline_b).count();
let hardened_overlap = hardened_a.intersection(&hardened_b).count(); let hardened_overlap = hardened_a.intersection(&hardened_b).count();
let baseline_gap = interval_gap_usize(&baseline_a, &baseline_b);
let hardened_gap = interval_gap_usize(&hardened_a, &hardened_b);
println!( println!(
"ab_harness_length baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}", "ab_harness_length baseline_overlap={} hardened_overlap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}",
baseline_overlap, baseline_overlap,
hardened_overlap, hardened_overlap,
baseline_gap,
hardened_gap,
baseline_a.len(), baseline_a.len(),
baseline_b.len(), baseline_b.len(),
hardened_a.len(), hardened_a.len(),
hardened_b.len() hardened_b.len()
); );
assert_eq!( assert_eq!(baseline_overlap, 0, "baseline above-cap classes should be disjoint");
baseline_overlap, 0,
"baseline above-cap classes should be disjoint"
);
assert!( assert!(
hardened_a.len() > baseline_a.len() && hardened_b.len() > baseline_b.len(), hardened_overlap > baseline_overlap,
"above-cap blur should widen per-class emitted lengths: baseline_a={} baseline_b={} hardened_a={} hardened_b={}", "above-cap blur should increase cross-class overlap: baseline={} hardened={}",
baseline_a.len(),
baseline_b.len(),
hardened_a.len(),
hardened_b.len()
);
assert!(
hardened_overlap > baseline_overlap || hardened_gap < baseline_gap,
"above-cap blur should reduce class separability via direct overlap or tighter interval gap: baseline_overlap={} hardened_overlap={} baseline_gap={} hardened_gap={}",
baseline_overlap, baseline_overlap,
hardened_overlap, hardened_overlap
baseline_gap,
hardened_gap
); );
} }
@ -350,10 +314,7 @@ fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() {
let a = vec![10u128, 11, 12, 13, 14]; let a = vec![10u128, 11, 12, 13, 14];
let b = vec![10u128, 11, 12, 13, 14]; let b = vec![10u128, 11, 12, 13, 14];
let acc = best_threshold_accuracy_u128(&a, &b); let acc = best_threshold_accuracy_u128(&a, &b);
assert!( assert!(acc <= 0.6, "identical sets should not be strongly separable");
acc <= 0.6,
"identical sets should not be strongly separable"
);
} }
#[test] #[test]
@ -375,10 +336,7 @@ async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separ
let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await;
let acc = best_threshold_accuracy_u128(&fail, &slow); let acc = best_threshold_accuracy_u128(&fail, &slow);
assert!( assert!(acc >= 0.80, "baseline timing classes should be separable enough");
acc >= 0.80,
"baseline timing classes should be separable enough"
);
} }
#[tokio::test] #[tokio::test]
@ -450,10 +408,7 @@ async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_
let fail_mean = mean_ms(&fail); let fail_mean = mean_ms(&fail);
let success_mean = mean_ms(&success); let success_mean = mean_ms(&success);
let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20; let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20;
assert!( assert!(delta_bucket <= 3, "mean bucket delta too large: {delta_bucket}");
delta_bucket <= 3,
"mean bucket delta too large: {delta_bucket}"
);
} }
#[tokio::test] #[tokio::test]
@ -463,10 +418,7 @@ async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_i
let p95_success = percentile_ms(success, 95, 100); let p95_success = percentile_ms(success, 95, 100);
let p95_slow = percentile_ms(slow, 95, 100); let p95_slow = percentile_ms(slow, 95, 100);
let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20; let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20;
assert!( assert!(delta_bucket <= 4, "p95 bucket delta too large: {delta_bucket}");
delta_bucket <= 4,
"p95 bucket delta too large: {delta_bucket}"
);
} }
#[tokio::test] #[tokio::test]
@ -482,10 +434,7 @@ async fn timing_classifier_normalized_spread_is_not_worse_than_baseline_for_conn
} }
#[tokio::test] #[tokio::test]
async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() {
{
const SAMPLE_COUNT: usize = 6;
let pairs = [ let pairs = [
(PathClass::ConnectFail, PathClass::ConnectSuccess), (PathClass::ConnectFail, PathClass::ConnectSuccess),
(PathClass::ConnectFail, PathClass::SlowBackend), (PathClass::ConnectFail, PathClass::SlowBackend),
@ -496,14 +445,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
let mut baseline_sum = 0.0f64; let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64; let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize; let mut pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
for (a, b) in pairs { for (a, b) in pairs {
let baseline_a = collect_timing_samples(a, false, SAMPLE_COUNT).await; let baseline_a = collect_timing_samples(a, false, 6).await;
let baseline_b = collect_timing_samples(b, false, SAMPLE_COUNT).await; let baseline_b = collect_timing_samples(b, false, 6).await;
let hardened_a = collect_timing_samples(a, true, SAMPLE_COUNT).await; let hardened_a = collect_timing_samples(a, true, 6).await;
let hardened_b = collect_timing_samples(b, true, SAMPLE_COUNT).await; let hardened_b = collect_timing_samples(b, true, 6).await;
let baseline_acc = best_threshold_accuracy_u128( let baseline_acc = best_threshold_accuracy_u128(
&bucketize_ms(&baseline_a, 20), &bucketize_ms(&baseline_a, 20),
@ -519,15 +466,11 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
// Guard hard only on informative baseline pairs. // Guard hard only on informative baseline pairs.
if baseline_acc >= 0.75 { if baseline_acc >= 0.75 {
assert!( assert!(
hardened_acc <= baseline_acc + tolerated_pair_regression, hardened_acc <= baseline_acc + 0.05,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}" "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3}"
); );
} }
println!(
"timing_classifier_pair baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated_pair_regression={tolerated_pair_regression:.3}"
);
if hardened_acc + 0.05 <= baseline_acc { if hardened_acc + 0.05 <= baseline_acc {
meaningful_improvement_seen = true; meaningful_improvement_seen = true;
} }
@ -541,7 +484,7 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
let hardened_avg = hardened_sum / pair_count as f64; let hardened_avg = hardened_sum / pair_count as f64;
assert!( assert!(
hardened_avg <= baseline_avg + 0.10, hardened_avg <= baseline_avg + 0.08,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
); );
@ -561,10 +504,7 @@ async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded()
_ => PathClass::SlowBackend, _ => PathClass::SlowBackend,
}; };
let sample = measure_masking_duration_ms(class, true).await; let sample = measure_masking_duration_ms(class, true).await;
assert!( assert!((100..=1600).contains(&sample), "stress sample out of bounds: {sample}");
(100..=1600).contains(&sample),
"stress sample out of bounds: {sample}"
);
})); }));
} }

View File

@ -1,13 +1,13 @@
use super::*; use super::*;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Instant, Duration};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::proxy::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stream::BufferPool; use crate::stream::BufferPool;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Probing Indistinguishability (OWASP ASVS 5.1.7) // Probing Indistinguishability (OWASP ASVS 5.1.7)
@ -19,7 +19,7 @@ async fn masking_probes_indistinguishable_timing() {
config.censorship.mask = true; config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string()); config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 80; // Should timeout/refuse config.censorship.mask_port = 80; // Should timeout/refuse
let peer: SocketAddr = "192.0.2.10:443".parse().unwrap(); let peer: SocketAddr = "192.0.2.10:443".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new(); let beobachten = BeobachtenStore::new();
@ -28,17 +28,14 @@ async fn masking_probes_indistinguishable_timing() {
let probes = vec![ let probes = vec![
(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"),
(b"SSH-2.0-probe".to_vec(), "SSH"), (b"SSH-2.0-probe".to_vec(), "SSH"),
( (vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"),
vec![0x16, 0x03, 0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00],
"TLS-scanner",
),
(vec![0x42; 5], "port-scanner"), (vec![0x42; 5], "port-scanner"),
]; ];
for (probe, type_name) in probes { for (probe, type_name) in probes {
let (client_reader, _client_writer) = duplex(256); let (client_reader, _client_writer) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256); let (_client_visible_reader, client_visible_writer) = duplex(256);
let start = Instant::now(); let start = Instant::now();
handle_bad_client( handle_bad_client(
client_reader, client_reader,
@ -48,17 +45,13 @@ async fn masking_probes_indistinguishable_timing() {
local_addr, local_addr,
&config, &config,
&beobachten, &beobachten,
) ).await;
.await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests)
// to mask whether the backend was reachable or refused. // to mask whether the backend was reachable or refused.
assert!( assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}");
elapsed >= Duration::from_millis(30),
"Probe {type_name} finished too fast: {elapsed:?}"
);
} }
} }
@ -83,7 +76,7 @@ async fn masking_budget_stress_under_load() {
let (_client_visible_reader, client_visible_writer) = duplex(256); let (_client_visible_reader, client_visible_writer) = duplex(256);
let config = config.clone(); let config = config.clone();
let beobachten = Arc::clone(&beobachten); let beobachten = Arc::clone(&beobachten);
tasks.push(tokio::spawn(async move { tasks.push(tokio::spawn(async move {
let start = Instant::now(); let start = Instant::now();
handle_bad_client( handle_bad_client(
@ -94,18 +87,14 @@ async fn masking_budget_stress_under_load() {
local_addr, local_addr,
&config, &config,
&beobachten, &beobachten,
) ).await;
.await;
start.elapsed() start.elapsed()
})); }));
} }
for task in tasks { for task in tasks {
let elapsed = task.await.unwrap(); let elapsed = task.await.unwrap();
assert!( assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}");
elapsed >= Duration::from_millis(30),
"Stress probe finished too fast: {elapsed:?}"
);
} }
} }
@ -119,10 +108,10 @@ fn test_detect_client_type_boundary_cases() {
assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner"); assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner");
// 10 bytes = unknown // 10 bytes = unknown
assert_eq!(detect_client_type(&[0x42; 10]), "unknown"); assert_eq!(detect_client_type(&[0x42; 10]), "unknown");
// HTTP verbs without trailing space // HTTP verbs without trailing space
assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10 assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10
assert_eq!(detect_client_type(b"GET /path"), "HTTP"); assert_eq!(detect_client_type(b"GET /path"), "HTTP");
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -144,9 +133,7 @@ async fn masking_slowloris_client_idle_timeout_rejected() {
assert_eq!(observed, initial); assert_eq!(observed, initial);
let mut drip = [0u8; 1]; let mut drip = [0u8; 1];
let drip_read = let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await;
tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip))
.await;
assert!( assert!(
drip_read.is_err() || drip_read.unwrap().is_err(), drip_read.is_err() || drip_read.unwrap().is_err(),
"backend must not receive post-timeout slowloris drip bytes" "backend must not receive post-timeout slowloris drip bytes"
@ -196,31 +183,18 @@ async fn masking_fallback_down_mimics_timeout() {
config.censorship.mask = true; config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string()); config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 1; // Unlikely port config.censorship.mask_port = 1; // Unlikely port
let (server_reader, server_writer) = duplex(1024); let (server_reader, server_writer) = duplex(1024);
let beobachten = BeobachtenStore::new(); let beobachten = BeobachtenStore::new();
let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap(); let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap();
let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); let local: SocketAddr = "192.0.2.1:443".parse().unwrap();
let start = Instant::now(); let start = Instant::now();
handle_bad_client( handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await;
server_reader,
server_writer,
b"GET / HTTP/1.1\r\n",
peer,
local,
&config,
&beobachten,
)
.await;
let elapsed = start.elapsed(); let elapsed = start.elapsed();
// It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately
assert!( assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed);
elapsed >= Duration::from_millis(40),
"Must respect connect budget even on failure: {:?}",
elapsed
);
} }
// ------------------------------------------------------------------ // ------------------------------------------------------------------
@ -231,13 +205,7 @@ async fn masking_fallback_down_mimics_timeout() {
async fn masking_ssrf_resolve_internal_ranges_blocked() { async fn masking_ssrf_resolve_internal_ranges_blocked() {
use crate::network::dns_overrides::resolve_socket_addr; use crate::network::dns_overrides::resolve_socket_addr;
let blocked_ips = [ let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"];
"127.0.0.1",
"169.254.169.254",
"10.0.0.1",
"192.168.1.1",
"0.0.0.0",
];
for ip in blocked_ips { for ip in blocked_ips {
assert!( assert!(
@ -302,10 +270,7 @@ async fn masking_zero_length_initial_data_does_not_hang_or_panic() {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert_eq!( assert_eq!(n, 0, "backend must observe clean EOF for empty initial payload");
n, 0,
"backend must observe clean EOF for empty initial payload"
);
}); });
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
@ -347,10 +312,7 @@ async fn masking_oversized_initial_payload_is_forwarded_verbatim() {
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let mut observed = vec![0u8; payload.len()]; let mut observed = vec![0u8; payload.len()];
stream.read_exact(&mut observed).await.unwrap(); stream.read_exact(&mut observed).await.unwrap();
assert_eq!( assert_eq!(observed, payload, "large initial payload must stay byte-for-byte");
observed, payload,
"large initial payload must stay byte-for-byte"
);
} }
}); });
@ -529,10 +491,7 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() {
}); });
let mut observed = vec![0u8; expected_reply.len()]; let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply); assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(2), handle) timeout(Duration::from_secs(2), handle)
@ -687,10 +646,7 @@ async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() {
}); });
let mut observed = vec![0u8; expected_reply.len()]; let mut observed = vec![0u8; expected_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, expected_reply); assert_eq!(observed, expected_reply);
timeout(Duration::from_secs(3), handle) timeout(Duration::from_secs(3), handle)

View File

@ -1,107 +0,0 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
async fn capture_forwarded_len_with_mode(
body_sent: usize,
close_client_after_write: bool,
aggressive_mode: bool,
above_cap_blur: bool,
above_cap_blur_max_bytes: usize,
) -> usize {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_shape_hardening = true;
config.censorship.mask_shape_hardening_aggressive_mode = aggressive_mode;
config.censorship.mask_shape_bucket_floor_bytes = 512;
config.censorship.mask_shape_bucket_cap_bytes = 4096;
config.censorship.mask_shape_above_cap_blur = above_cap_blur;
config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
got.len()
});
let (server_reader, mut client_writer) = duplex(64 * 1024);
let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
let peer: SocketAddr = "198.51.100.248:57248".parse().unwrap();
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let mut probe = vec![0u8; 5 + body_sent];
probe[0] = 0x16;
probe[1] = 0x03;
probe[2] = 0x01;
probe[3..5].copy_from_slice(&7000u16.to_be_bytes());
probe[5..].fill(0x31);
let fallback = tokio::spawn(async move {
handle_bad_client(
server_reader,
client_visible_writer,
&probe,
peer,
local,
&config,
&beobachten,
)
.await;
});
if close_client_after_write {
client_writer.shutdown().await.unwrap();
} else {
client_writer.write_all(b"keepalive").await.unwrap();
tokio::time::sleep(Duration::from_millis(170)).await;
drop(client_writer);
}
let _ = tokio::time::timeout(Duration::from_secs(4), fallback)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(4), accept_task)
.await
.unwrap()
.unwrap()
}
#[tokio::test]
async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
let body_sent = 17usize;
let floor = 512usize;
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
assert!(
aggressive >= floor,
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"
);
}
#[tokio::test]
async fn aggressive_mode_enforces_positive_above_cap_blur() {
let body_sent = 5000usize;
let base = 5 + body_sent;
for _ in 0..48 {
let observed = capture_forwarded_len_with_mode(body_sent, true, true, true, 1).await;
assert!(
observed > base,
"aggressive mode must not emit exact base length when blur is enabled (observed={observed}, base={base})"
);
}
}

View File

@ -1,14 +1,14 @@
use super::*; use super::*;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::io::{AsyncBufReadExt, BufReader, duplex}; use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[cfg(unix)] #[cfg(unix)]
use tokio::net::UnixListener; use tokio::net::UnixListener;
use tokio::time::{Duration, Instant, sleep, timeout}; use tokio::time::{Instant, sleep, timeout, Duration};
#[tokio::test] #[tokio::test]
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
@ -56,10 +56,7 @@ async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -111,10 +108,7 @@ async fn tls_scanner_probe_keeps_http_like_fallback_surface() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
@ -153,8 +147,8 @@ fn build_mask_proxy_header_v2_matches_builder_output() {
let expected = ProxyProtocolV2Builder::new() let expected = ProxyProtocolV2Builder::new()
.with_addrs(peer, local_addr) .with_addrs(peer, local_addr)
.build(); .build();
let actual = let actual = build_mask_proxy_header(2, peer, local_addr)
build_mask_proxy_header(2, peer, local_addr).expect("v2 mode must produce a header"); .expect("v2 mode must produce a header");
assert_eq!(actual, expected, "v2 header bytes must be deterministic"); assert_eq!(actual, expected, "v2 header bytes must be deterministic");
} }
@ -165,8 +159,8 @@ fn build_mask_proxy_header_v1_mixed_ip_family_uses_generic_unknown_form() {
let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let expected = ProxyProtocolV1Builder::new().build(); let expected = ProxyProtocolV1Builder::new().build();
let actual = let actual = build_mask_proxy_header(1, peer, local_addr)
build_mask_proxy_header(1, peer, local_addr).expect("v1 mode must produce a header"); .expect("v1 mode must produce a header");
assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form"); assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form");
} }
@ -203,10 +197,7 @@ async fn beobachten_records_scanner_class_when_mask_is_disabled() {
client_reader_side.write_all(b"noise").await.unwrap(); client_reader_side.write_all(b"noise").await.unwrap();
drop(client_reader_side); drop(client_reader_side);
let beobachten = timeout(Duration::from_secs(3), task) let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[SSH]")); assert!(snapshot.contains("[SSH]"));
assert!(snapshot.contains("203.0.113.99-1")); assert!(snapshot.contains("203.0.113.99-1"));
@ -250,10 +241,7 @@ async fn backend_unavailable_falls_back_to_silent_consume() {
client_reader_side.write_all(b"noise").await.unwrap(); client_reader_side.write_all(b"noise").await.unwrap();
drop(client_reader_side); drop(client_reader_side);
timeout(Duration::from_secs(3), task) timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
@ -405,9 +393,9 @@ async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget
.await; .await;
}); });
timeout(Duration::from_millis(35), task).await.expect_err( timeout(Duration::from_millis(35), task)
"proxy-header write error path should remain inside coarse masking budget window", .await
); .expect_err("proxy-header write error path should remain inside coarse masking budget window");
assert!( assert!(
started.elapsed() >= Duration::from_millis(35), started.elapsed() >= Duration::from_millis(35),
"proxy-header write error path should avoid immediate-return timing signature" "proxy-header write error path should avoid immediate-return timing signature"
@ -462,9 +450,9 @@ async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budge
.await; .await;
}); });
timeout(Duration::from_millis(35), task).await.expect_err( timeout(Duration::from_millis(35), task)
"unix proxy-header write error path should remain inside coarse masking budget window", .await
); .expect_err("unix proxy-header write error path should remain inside coarse masking budget window");
assert!( assert!(
started.elapsed() >= Duration::from_millis(35), started.elapsed() >= Duration::from_millis(35),
"unix proxy-header write error path should avoid immediate-return timing signature" "unix proxy-header write error path should avoid immediate-return timing signature"
@ -498,14 +486,8 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() {
let mut header_line = Vec::new(); let mut header_line = Vec::new();
reader.read_until(b'\n', &mut header_line).await.unwrap(); reader.read_until(b'\n', &mut header_line).await.unwrap();
let header_text = String::from_utf8(header_line).unwrap(); let header_text = String::from_utf8(header_line).unwrap();
assert!( assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix");
header_text.starts_with("PROXY "), assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF");
"must start with PROXY prefix"
);
assert!(
header_text.ends_with("\r\n"),
"v1 header must end with CRLF"
);
let mut received_probe = vec![0u8; probe.len()]; let mut received_probe = vec![0u8; probe.len()];
reader.read_exact(&mut received_probe).await.unwrap(); reader.read_exact(&mut received_probe).await.unwrap();
@ -541,10 +523,7 @@ async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -573,10 +552,7 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
let mut sig = [0u8; 12]; let mut sig = [0u8; 12];
stream.read_exact(&mut sig).await.unwrap(); stream.read_exact(&mut sig).await.unwrap();
assert_eq!( assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec");
&sig, b"\r\n\r\n\0\r\nQUIT\n",
"v2 signature must match spec"
);
let mut fixed = [0u8; 4]; let mut fixed = [0u8; 4];
stream.read_exact(&mut fixed).await.unwrap(); stream.read_exact(&mut fixed).await.unwrap();
@ -617,10 +593,7 @@ async fn unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -920,16 +893,10 @@ async fn mask_disabled_consumes_client_data_without_response() {
.await; .await;
}); });
client_reader_side client_reader_side.write_all(b"untrusted payload").await.unwrap();
.write_all(b"untrusted payload")
.await
.unwrap();
drop(client_reader_side); drop(client_reader_side);
timeout(Duration::from_secs(3), task) timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
@ -995,10 +962,7 @@ async fn proxy_protocol_v1_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1062,10 +1026,7 @@ async fn proxy_protocol_v2_header_is_sent_before_probe() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1125,10 +1086,7 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
@ -1136,11 +1094,7 @@ async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() {
#[cfg(unix)] #[cfg(unix)]
#[tokio::test] #[tokio::test]
async fn unix_socket_mask_path_forwards_probe_and_response() { async fn unix_socket_mask_path_forwards_probe_and_response() {
let sock_path = format!( let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::<u64>());
"/tmp/telemt-mask-test-{}-{}.sock",
std::process::id(),
rand::random::<u64>()
);
let _ = std::fs::remove_file(&sock_path); let _ = std::fs::remove_file(&sock_path);
let listener = UnixListener::bind(&sock_path).unwrap(); let listener = UnixListener::bind(&sock_path).unwrap();
@ -1184,10 +1138,7 @@ async fn unix_socket_mask_path_forwards_probe_and_response() {
.await; .await;
let mut observed = vec![0u8; backend_reply.len()]; let mut observed = vec![0u8; backend_reply.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, backend_reply); assert_eq!(observed, backend_reply);
accept_task.await.unwrap(); accept_task.await.unwrap();
@ -1220,10 +1171,7 @@ async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
.await; .await;
}); });
timeout(Duration::from_secs(1), task) timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -1375,27 +1323,20 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
0, 0,
false, false,
0, 0,
false,
) )
.await; .await;
}); });
// Allow relay tasks to start, then emulate mask backend response. // Allow relay tasks to start, then emulate mask backend response.
sleep(Duration::from_millis(20)).await; sleep(Duration::from_millis(20)).await;
backend_feed_writer backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
.write_all(b"HTTP/1.1 200 OK\r\n\r\n")
.await
.unwrap();
backend_feed_writer.shutdown().await.unwrap(); backend_feed_writer.shutdown().await.unwrap();
let mut observed = vec![0u8; 19]; let mut observed = vec![0u8; 19];
timeout( timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed))
Duration::from_secs(1), .await
client_visible_reader.read_exact(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n"); assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n");
relay.abort(); relay.abort();
@ -1453,23 +1394,14 @@ async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
client_write.shutdown().await.unwrap(); client_write.shutdown().await.unwrap();
let mut observed_resp = vec![0u8; response.len()]; let mut observed_resp = vec![0u8; response.len()];
timeout( timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp))
Duration::from_secs(1), .await
client_visible_reader.read_exact(&mut observed_resp), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(observed_resp, response); assert_eq!(observed_resp, response);
timeout(Duration::from_secs(1), fallback_task) timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
.await timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
.unwrap()
.unwrap();
timeout(Duration::from_secs(1), backend_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -1505,7 +1437,6 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
0, 0,
false, false,
0, 0,
false,
), ),
) )
.await; .await;
@ -1643,11 +1574,9 @@ async fn timing_matrix_masking_classes_under_controlled_inputs() {
(mean, min, p95, max) (mean, min, p95, max)
} }
let (disabled_mean, disabled_min, disabled_p95, disabled_max) = let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples);
summarize(&mut disabled_samples);
let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples); let (refused_mean, refused_min, refused_p95, refused_max) = summarize(&mut refused_samples);
let (reachable_mean, reachable_min, reachable_p95, reachable_max) = let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples);
summarize(&mut reachable_samples);
println!( println!(
"TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}",
@ -1769,10 +1698,7 @@ async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() {
let elapsed = started.elapsed(); let elapsed = started.elapsed();
let mut observed = vec![0u8; response.len()]; let mut observed = vec![0u8; response.len()];
client_visible_reader client_visible_reader.read_exact(&mut observed).await.unwrap();
.read_exact(&mut observed)
.await
.unwrap();
assert_eq!(observed, response); assert_eq!(observed, response);
assert!( assert!(
elapsed < Duration::from_millis(190), elapsed < Duration::from_millis(190),
@ -1837,9 +1763,6 @@ async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() {
let _ = client_writer_side.write_all(b"X").await; let _ = client_writer_side.write_all(b"X").await;
drop(client_writer_side); drop(client_writer_side);
timeout(Duration::from_secs(1), relay_task) timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
accept_task.await.unwrap(); accept_task.await.unwrap();
} }

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;

View File

@ -1,182 +0,0 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
/// Drives `handle_bad_client` with a forged TLS-record probe and returns the
/// number of bytes the masked backend ultimately received on its socket.
///
/// The probe is a 5-byte TLS record header followed by `body_sent` filler
/// bytes; the header's declared length is fixed at 7000, so declared and
/// actual lengths intentionally disagree. When `close_client_after_write` is
/// true the client half is shut down immediately (clean-EOF path); otherwise
/// the client sends a keepalive, lingers ~170ms, and is dropped (timeout path).
async fn capture_forwarded_len_with_optional_eof(
    body_sent: usize,
    shape_hardening: bool,
    above_cap_blur: bool,
    above_cap_blur_max_bytes: usize,
    close_client_after_write: bool,
) -> usize {
    // Real TCP listener stands in for the mask backend and counts bytes seen.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_host = Some("127.0.0.1".to_string());
    config.censorship.mask_port = backend_addr.port();
    // Shaping parameters under test; floor/cap are fixed, blur is caller-chosen.
    config.censorship.mask_shape_hardening = shape_hardening;
    config.censorship.mask_shape_bucket_floor_bytes = 512;
    config.censorship.mask_shape_bucket_cap_bytes = 4096;
    config.censorship.mask_shape_above_cap_blur = above_cap_blur;
    config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes;
    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = Vec::new();
        // Bounded read: on the timeout path the backend may never see EOF.
        let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await;
        got.len()
    });
    let (server_reader, mut client_writer) = duplex(64 * 1024);
    let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024);
    let peer: SocketAddr = "198.51.100.241:57241".parse().unwrap();
    let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
    let beobachten = BeobachtenStore::new();
    // Forge a TLS handshake record: type 0x16, version 3.1, declared len 7000.
    let mut probe = vec![0u8; 5 + body_sent];
    probe[0] = 0x16;
    probe[1] = 0x03;
    probe[2] = 0x01;
    probe[3..5].copy_from_slice(&7000u16.to_be_bytes());
    probe[5..].fill(0x73);
    let fallback = tokio::spawn(async move {
        handle_bad_client(
            server_reader,
            client_visible_writer,
            &probe,
            peer,
            local,
            &config,
            &beobachten,
        )
        .await;
    });
    if close_client_after_write {
        // Clean-EOF path: client half-closes right after the probe.
        client_writer.shutdown().await.unwrap();
    } else {
        // Timeout path: stay alive past the idle window, then drop abruptly.
        client_writer.write_all(b"keepalive").await.unwrap();
        tokio::time::sleep(Duration::from_millis(170)).await;
        drop(client_writer);
    }
    let _ = tokio::time::timeout(Duration::from_secs(4), fallback)
        .await
        .unwrap()
        .unwrap();
    // The backend-observed byte count is the measurement under test.
    tokio::time::timeout(Duration::from_secs(4), accept_task)
        .await
        .unwrap()
        .unwrap()
}
#[tokio::test]
#[ignore = "red-team detector: shaping on non-EOF timeout path is disabled by design to prevent post-timeout tail leaks"]
async fn security_shape_padding_applies_without_client_eof_when_backend_silent() {
    // A 17-byte body sits far below the 512-byte shaping floor, so any shaped
    // exit path must pad the forwarded stream up to at least that floor.
    const BODY: usize = 17;
    let hardened_floor = 512usize;
    let with_eof = capture_forwarded_len_with_optional_eof(BODY, true, false, 0, true).await;
    let without_eof = capture_forwarded_len_with_optional_eof(BODY, true, false, 0, false).await;
    let eof_path_shaped = with_eof >= hardened_floor;
    assert!(
        eof_path_shaped,
        "EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})"
    );
    let timeout_path_shaped = without_eof >= hardened_floor;
    assert!(
        timeout_path_shaped,
        "non-EOF path should also be shaped when backend is silent (without_eof={without_eof}, floor={hardened_floor})"
    );
}
#[tokio::test]
#[ignore = "red-team detector: blur currently allows zero-extra sample by design within [0..=max] bound"]
async fn security_above_cap_blur_never_emits_exact_base_length() {
    // Base length = 5-byte record header + body; the detector demands that
    // blur always adds at least one byte on top of it.
    let body_len = 5000usize;
    let base = 5 + body_len;
    let max_blur = 1usize;
    let mut round = 0;
    while round < 64 {
        let observed =
            capture_forwarded_len_with_optional_eof(body_len, true, true, max_blur, true).await;
        assert!(
            observed > base,
            "above-cap blur must add at least one byte when enabled (observed={observed}, base={base})"
        );
        round += 1;
    }
}
#[tokio::test]
#[ignore = "red-team detector: shape padding currently depends on EOF, enabling idle-timeout bypass probes"]
async fn redteam_detector_shape_padding_must_not_depend_on_client_eof() {
    // Probe both exit paths with the same tiny body; a strict anti-probing
    // posture would shape each of them up to the hardened floor.
    let hardened_floor = 512usize;
    let probe_body = 17usize;
    let with_eof = capture_forwarded_len_with_optional_eof(probe_body, true, false, 0, true).await;
    let without_eof =
        capture_forwarded_len_with_optional_eof(probe_body, true, false, 0, false).await;
    assert!(
        with_eof >= hardened_floor,
        "sanity check failed: EOF path should be shaped to floor (with_eof={with_eof}, floor={hardened_floor})"
    );
    assert!(
        without_eof >= hardened_floor,
        "strict anti-probing model expects shaping even without EOF; observed without_eof={without_eof}, floor={hardened_floor}"
    );
}
#[tokio::test]
#[ignore = "red-team detector: zero-extra above-cap blur samples leak exact class boundary"]
async fn redteam_detector_above_cap_blur_must_never_emit_exact_base_length() {
    // Repeatedly sample the blurred wire length and stop as soon as a sample
    // lands exactly on the unblurred base length.
    let body_len = 5000usize;
    let base = 5 + body_len;
    let max_blur = 1usize;
    let mut exact_base_hit = false;
    for _attempt in 0..96 {
        let observed =
            capture_forwarded_len_with_optional_eof(body_len, true, true, max_blur, true).await;
        exact_base_hit = observed == base;
        if exact_base_hit {
            break;
        }
    }
    assert!(
        !exact_base_hit,
        "strict anti-classifier model expects >0 blur always; observed exact base length leaks class"
    );
}
#[tokio::test]
#[ignore = "red-team detector: disjoint above-cap ranges enable near-perfect size-class classification"]
async fn redteam_detector_above_cap_blur_ranges_for_far_classes_should_overlap() {
    // Sample the observed wire lengths for two distant size classes and derive
    // the [min, max] band each one produces under above-cap blur.
    let mut class_a = Vec::with_capacity(48);
    let mut class_b = Vec::with_capacity(48);
    for _ in 0..48 {
        class_a.push(capture_forwarded_len_with_optional_eof(5000, true, true, 64, true).await);
        class_b.push(capture_forwarded_len_with_optional_eof(7000, true, true, 64, true).await);
    }
    let a_min = *class_a.iter().min().unwrap();
    let a_max = *class_a.iter().max().unwrap();
    let b_min = *class_b.iter().min().unwrap();
    let b_max = *class_b.iter().max().unwrap();
    // Bands overlap iff neither band sits strictly above the other.
    let overlap = a_min <= b_max && b_min <= a_max;
    assert!(
        overlap,
        "strict anti-classifier model expects overlapping output bands; class_a=[{a_min},{a_max}] class_b=[{b_min},{b_max}]"
    );
}

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::Duration; use tokio::time::Duration;
@ -90,7 +90,9 @@ fn nearest_centroid_classifier_accuracy(
samples_b: &[usize], samples_b: &[usize],
samples_c: &[usize], samples_c: &[usize],
) -> f64 { ) -> f64 {
let mean = |xs: &[usize]| -> f64 { xs.iter().copied().sum::<usize>() as f64 / xs.len() as f64 }; let mean = |xs: &[usize]| -> f64 {
xs.iter().copied().sum::<usize>() as f64 / xs.len() as f64
};
let ca = mean(samples_a); let ca = mean(samples_a);
let cb = mean(samples_b); let cb = mean(samples_b);
@ -102,7 +104,11 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_a { for &x in samples_a {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; let d = [
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[0] <= d[1] && d[0] <= d[2] { if d[0] <= d[1] && d[0] <= d[2] {
correct += 1; correct += 1;
} }
@ -111,7 +117,11 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_b { for &x in samples_b {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; let d = [
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[1] <= d[0] && d[1] <= d[2] { if d[1] <= d[0] && d[1] <= d[2] {
correct += 1; correct += 1;
} }
@ -120,7 +130,11 @@ fn nearest_centroid_classifier_accuracy(
for &x in samples_c { for &x in samples_c {
total += 1; total += 1;
let xf = x as f64; let xf = x as f64;
let d = [(xf - ca).abs(), (xf - cb).abs(), (xf - cc).abs()]; let d = [
(xf - ca).abs(),
(xf - cb).abs(),
(xf - cc).abs(),
];
if d[2] <= d[0] && d[2] <= d[1] { if d[2] <= d[0] && d[2] <= d[1] {
correct += 1; correct += 1;
} }
@ -152,10 +166,7 @@ async fn masking_shape_classifier_resistance_blur_reduces_threshold_attack_accur
let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b); let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b);
// Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack. // Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack.
assert!( assert!(baseline_acc >= 0.99, "baseline separability unexpectedly low: {baseline_acc:.3}");
baseline_acc >= 0.99,
"baseline separability unexpectedly low: {baseline_acc:.3}"
);
// Blur must materially reduce the best one-dimensional length classifier. // Blur must materially reduce the best one-dimensional length classifier.
assert!( assert!(
hardened_acc <= 0.90, hardened_acc <= 0.90,
@ -236,11 +247,7 @@ async fn masking_shape_classifier_resistance_edge_max_extra_one_has_two_point_su
seen.insert(observed); seen.insert(observed);
} }
assert_eq!( assert_eq!(seen.len(), 2, "both support points should appear under repeated sampling");
seen.len(),
2,
"both support points should appear under repeated sampling"
);
} }
#[tokio::test] #[tokio::test]
@ -255,25 +262,13 @@ async fn masking_shape_classifier_resistance_negative_blur_without_shape_hardeni
bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await); bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await);
} }
assert_eq!( assert_eq!(as_observed.len(), 1, "without shape hardening class A must stay deterministic");
as_observed.len(), assert_eq!(bs_observed.len(), 1, "without shape hardening class B must stay deterministic");
1, assert_ne!(as_observed, bs_observed, "distinct classes should remain separable without shaping");
"without shape hardening class A must stay deterministic"
);
assert_eq!(
bs_observed.len(),
1,
"without shape hardening class B must stay deterministic"
);
assert_ne!(
as_observed, bs_observed,
"distinct classes should remain separable without shaping"
);
} }
#[tokio::test] #[tokio::test]
async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() {
{
const SAMPLES: usize = 80; const SAMPLES: usize = 80;
const MAX_EXTRA: usize = 96; const MAX_EXTRA: usize = 96;
const C1: usize = 5000; const C1: usize = 5000;
@ -300,23 +295,13 @@ async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_at
let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3); let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3);
let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3); let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3);
assert!( assert!(base_acc >= 0.99, "baseline centroid separability should be near-perfect");
base_acc >= 0.99, assert!(hard_acc <= 0.88, "blur should materially degrade 3-class centroid attack");
"baseline centroid separability should be near-perfect" assert!(hard_acc <= base_acc - 0.1, "accuracy drop should be meaningful");
);
assert!(
hard_acc <= 0.88,
"blur should materially degrade 3-class centroid attack"
);
assert!(
hard_acc <= base_acc - 0.1,
"accuracy drop should be meaningful"
);
} }
#[tokio::test] #[tokio::test]
async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() {
{
let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE; let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE;
for _ in 0..96 { for _ in 0..96 {
s ^= s << 7; s ^= s << 7;

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, empty, sink}; use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, sleep, timeout}; use tokio::time::{sleep, timeout, Duration};
fn oracle_len( fn oracle_len(
total_sent: usize, total_sent: usize,
@ -42,7 +42,6 @@ async fn run_relay_case(
cap, cap,
above_cap_blur, above_cap_blur,
above_cap_blur_max_bytes, above_cap_blur_max_bytes,
false,
) )
.await; .await;
}); });
@ -55,23 +54,17 @@ async fn run_relay_case(
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
} }
timeout(Duration::from_secs(2), relay) timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
if !close_client { if !close_client {
drop(client_writer); drop(client_writer);
} }
let mut observed = Vec::new(); let mut observed = Vec::new();
timeout( timeout(Duration::from_secs(2), mask_observer.read_to_end(&mut observed))
Duration::from_secs(2), .await
mask_observer.read_to_end(&mut observed), .unwrap()
) .unwrap();
.await
.unwrap()
.unwrap();
observed observed
} }
@ -104,29 +97,12 @@ async fn masking_shape_guard_positive_clean_eof_path_shapes_and_preserves_prefix
let extra = vec![0x55; 300]; let extra = vec![0x55; 300];
let total = initial.len() + extra.len(); let total = initial.len() + extra.len();
let observed = run_relay_case( let observed = run_relay_case(initial.clone(), extra.clone(), true, true, 512, 4096, false, 0).await;
initial.clone(),
extra.clone(),
true,
true,
512,
4096,
false,
0,
)
.await;
let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096); let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096);
assert_eq!( assert_eq!(observed.len(), expected_len, "clean EOF path must be bucket-shaped");
observed.len(),
expected_len,
"clean EOF path must be bucket-shaped"
);
assert_eq!(&observed[..initial.len()], initial.as_slice()); assert_eq!(&observed[..initial.len()], initial.as_slice());
assert_eq!( assert_eq!(&observed[initial.len()..(initial.len() + extra.len())], extra.as_slice());
&observed[initial.len()..(initial.len() + extra.len())],
extra.as_slice()
);
} }
#[tokio::test] #[tokio::test]
@ -136,11 +112,7 @@ async fn masking_shape_guard_edge_empty_initial_remains_transparent_under_clean_
let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await; let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await;
assert_eq!( assert_eq!(observed.len(), extra.len(), "empty initial_data must never trigger shaping");
observed.len(),
extra.len(),
"empty initial_data must never trigger shaping"
);
assert_eq!(observed, extra); assert_eq!(observed, extra);
} }
@ -240,19 +212,13 @@ async fn masking_shape_guard_stress_parallel_mixed_sessions_keep_oracle_and_no_h
assert_eq!(&observed[..initial_len], initial.as_slice()); assert_eq!(&observed[..initial_len], initial.as_slice());
} }
if extra_len > 0 { if extra_len > 0 {
assert_eq!( assert_eq!(&observed[initial_len..(initial_len + extra_len)], extra.as_slice());
&observed[initial_len..(initial_len + extra_len)],
extra.as_slice()
);
} }
})); }));
} }
for task in tasks { for task in tasks {
timeout(Duration::from_secs(3), task) timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
} }
} }
@ -272,10 +238,7 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l
let mut one = [0u8; 1]; let mut one = [0u8; 1];
let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await;
assert!( assert!(r.is_err() || r.unwrap().is_err(), "no post-timeout drip/tail may reach backend");
r.is_err() || r.unwrap().is_err(),
"no post-timeout drip/tail may reach backend"
);
} }
}); });
@ -311,14 +274,8 @@ async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_l
sleep(Duration::from_millis(160)).await; sleep(Duration::from_millis(160)).await;
let _ = client_writer.write_all(b"X").await; let _ = client_writer.write_all(b"X").await;
timeout(Duration::from_secs(2), relay) timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
.await timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap();
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -395,10 +352,7 @@ async fn masking_shape_guard_above_cap_blur_parallel_stress_keeps_bounds() {
} }
for task in tasks { for task in tasks {
timeout(Duration::from_secs(3), task) timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
} }
} }

View File

@ -1,7 +1,7 @@
use super::*; use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::time::{Duration, timeout}; use tokio::time::{timeout, Duration};
#[tokio::test] #[tokio::test]
async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() { async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() {
@ -15,10 +15,7 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof()
let (mut stream, _) = listener.accept().await.unwrap(); let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new(); let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap(); stream.read_to_end(&mut got).await.unwrap();
assert_eq!( assert_eq!(got, expected, "empty initial_data path must not inject shape padding");
got, expected,
"empty initial_data path must not inject shape padding"
);
} }
}); });
@ -54,14 +51,8 @@ async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof()
client_writer.write_all(&client_payload).await.unwrap(); client_writer.write_all(&client_payload).await.unwrap();
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay_task) timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap();
.await timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap();
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -114,10 +105,7 @@ async fn shape_guard_timeout_exit_does_not_append_padding_after_initial_probe()
) )
.await; .await;
timeout(Duration::from_secs(2), accept_task) timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -138,11 +126,7 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin
let expected_prefix_len = initial.len() + extra.len(); let expected_prefix_len = initial.len() + extra.len();
assert_eq!(&got[..initial.len()], initial.as_slice()); assert_eq!(&got[..initial.len()], initial.as_slice());
assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice()); assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice());
assert_eq!( assert_eq!(got.len(), 512, "clean EOF path should still shape to floor bucket");
got.len(),
512,
"clean EOF path should still shape to floor bucket"
);
} }
}); });
@ -178,12 +162,6 @@ async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_paddin
client_writer.write_all(&extra).await.unwrap(); client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap(); client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay_task) timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap();
.await timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap();
.unwrap()
.unwrap();
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
} }

View File

@ -1,5 +1,5 @@
use super::*; use super::*;
use tokio::io::{AsyncReadExt, AsyncWrite, duplex, empty, sink}; use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWrite};
struct CountingWriter { struct CountingWriter {
written: usize, written: usize,
@ -46,24 +46,21 @@ fn shape_bucket_clamps_to_cap_when_next_power_of_two_exceeds_cap() {
fn shape_bucket_never_drops_below_total_for_valid_ranges() { fn shape_bucket_never_drops_below_total_for_valid_ranges() {
for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] { for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] {
let bucket = next_mask_shape_bucket(total, 1000, 1500); let bucket = next_mask_shape_bucket(total, 1000, 1500);
assert!( assert!(bucket >= total || total >= 1500, "bucket={bucket} total={total}");
bucket >= total || total >= 1500,
"bucket={bucket} total={total}"
);
} }
} }
#[tokio::test] #[tokio::test]
async fn maybe_write_shape_padding_writes_exact_delta() { async fn maybe_write_shape_padding_writes_exact_delta() {
let mut writer = CountingWriter::new(); let mut writer = CountingWriter::new();
maybe_write_shape_padding(&mut writer, 1200, true, 1000, 1500, false, 0, false).await; maybe_write_shape_padding(&mut writer, 1200, true, 1000, 1500, false, 0).await;
assert_eq!(writer.written, 300); assert_eq!(writer.written, 300);
} }
#[tokio::test] #[tokio::test]
async fn maybe_write_shape_padding_skips_when_disabled() { async fn maybe_write_shape_padding_skips_when_disabled() {
let mut writer = CountingWriter::new(); let mut writer = CountingWriter::new();
maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0, false).await; maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0).await;
assert_eq!(writer.written, 0); assert_eq!(writer.written, 0);
} }
@ -87,7 +84,6 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
1500, 1500,
false, false,
0, 0,
false,
) )
.await; .await;
}); });

View File

@ -115,12 +115,6 @@ async fn timing_normalization_does_not_sleep_if_path_already_exceeds_ceiling() {
let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await; let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await;
assert!( assert!(slow >= 280, "slow backend path should remain slow (got {slow}ms)");
slow >= 280, assert!(slow <= 520, "slow backend path should remain bounded in tests (got {slow}ms)");
"slow backend path should remain slow (got {slow}ms)"
);
assert!(
slow <= 520,
"slow backend path should remain bounded in tests (got {slow}ms)"
);
} }

View File

@ -1,112 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
/// Black-hat style stress test: saturates the bounded quota-lock cache, keeps
/// a relay command queue under backpressure, and races 16 same-user quota
/// contenders concurrently — quota accounting must stay fail-closed (at most
/// one contender admitted, no overshoot) throughout.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
    // Serialize against other tests that touch the global quota-lock cache.
    let _guard = super::quota_user_lock_test_scope();
    let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
    map.clear();
    // Hold QUOTA_USER_LOCKS_MAX distinct user locks alive so the bounded cache
    // is fully saturated before the race starts.
    let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
    for idx in 0..QUOTA_USER_LOCKS_MAX {
        retained.push(quota_user_lock(&format!(
            "middle-blackhat-held-{}-{idx}",
            std::process::id()
        )));
    }
    assert_eq!(
        map.len(),
        QUOTA_USER_LOCKS_MAX,
        "precondition: bounded lock cache must be saturated"
    );
    // Capacity-1 channel prefilled with one command: every further enqueue
    // sees a full queue (the receiver is never drained).
    let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
    tx.send(C2MeCommand::Close)
        .await
        .expect("queue prefill should succeed");
    let pressure_seq_before = relay_pressure_event_seq();
    let pressure_errors = Arc::new(AtomicUsize::new(0));
    // 16 workers hammer the full queue; each failed enqueue is counted.
    let mut pressure_workers = Vec::new();
    for _ in 0..16 {
        let tx = tx.clone();
        let pressure_errors = Arc::clone(&pressure_errors);
        pressure_workers.push(tokio::spawn(async move {
            if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() {
                pressure_errors.fetch_add(1, Ordering::Relaxed);
            }
        }));
    }
    let stats = Arc::new(Stats::new());
    let user = format!("middle-blackhat-quota-race-{}", std::process::id());
    // Barrier releases all 16 quota contenders at once to maximize contention.
    let gate = Arc::new(Barrier::new(16));
    let mut quota_workers = Vec::new();
    for _ in 0..16u8 {
        let stats = Arc::clone(&stats);
        let user = user.clone();
        let gate = Arc::clone(&gate);
        quota_workers.push(tokio::spawn(async move {
            gate.wait().await;
            // Per-user lock makes check-then-add atomic; returns whether this
            // contender was admitted (true) or quota-denied (false).
            let user_lock = quota_user_lock(&user);
            let _quota_guard = user_lock.lock().await;
            if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) {
                return false;
            }
            stats.add_user_octets_to(&user, 1);
            true
        }));
    }
    let mut ok_count = 0usize;
    let mut denied_count = 0usize;
    for worker in quota_workers {
        let result = timeout(Duration::from_secs(2), worker)
            .await
            .expect("quota worker must finish")
            .expect("quota worker must not panic");
        if result {
            ok_count += 1;
        } else {
            denied_count += 1;
        }
    }
    for worker in pressure_workers {
        timeout(Duration::from_secs(2), worker)
            .await
            .expect("pressure worker must finish")
            .expect("pressure worker must not panic");
    }
    // Fail-closed invariants: exactly one octet accounted, at most one winner.
    assert_eq!(
        stats.get_user_total_octets(&user),
        1,
        "black-hat campaign must not overshoot same-user quota under saturation"
    );
    assert!(ok_count <= 1, "at most one quota contender may succeed");
    assert!(
        denied_count >= 15,
        "all remaining contenders must be quota-denied"
    );
    let pressure_seq_after = relay_pressure_event_seq();
    assert!(
        pressure_seq_after > pressure_seq_before,
        "queue pressure leg must trigger pressure accounting"
    );
    assert!(
        pressure_errors.load(Ordering::Relaxed) >= 1,
        "at least one pressure worker should fail from persistent backpressure"
    );
    drop(retained);
}

View File

@ -1,708 +0,0 @@
use super::*;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
use crate::stats::Stats;
use crate::stream::{BufferPool, PooledBuffer};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::time::{Duration as TokioDuration, timeout};
/// Copies `data` into a freshly created pool-backed buffer sized to fit it.
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
    // The pool's chunk size must be nonzero even for an empty payload.
    let chunk = data.len().max(1);
    let pool = Arc::new(BufferPool::with_config(chunk, 4));
    let mut buf = pool.get();
    buf.resize(data.len(), 0);
    buf[..data.len()].copy_from_slice(data);
    buf
}
/// Abridged framing, short form: the single header byte carries the payload
/// length in 4-byte words, and quickack is signalled via its 0x80 bit.
#[tokio::test]
async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() {
    let (mut read_side, write_side) = duplex(4096);
    // Writer and checker use the same key/IV so their CTR keystreams line up.
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    // 8 bytes = 2 words, well inside the short single-byte length form.
    let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40];
    write_client_payload(
        &mut writer,
        ProtoTag::Abridged,
        RPC_FLAG_QUICKACK,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect("abridged quickack payload should serialize");
    writer.flush().await.expect("flush must succeed");
    // Wire frame is exactly one header byte plus the payload.
    let mut encrypted = vec![0u8; 1 + payload.len()];
    read_side
        .read_exact(&mut encrypted)
        .await
        .expect("must read serialized abridged frame");
    let plaintext = decryptor.decrypt(&encrypted);
    // Header byte: word count (len / 4) with the quickack bit (0x80) set.
    assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8));
    assert_eq!(&plaintext[1..], payload.as_slice());
}
/// Abridged framing, extended form: once the word count reaches 0x7f, the
/// header becomes a marker byte (0x7f, here OR'd with the 0x80 quickack bit)
/// followed by a 3-byte little-endian word count.
#[tokio::test]
async fn write_client_payload_abridged_extended_header_is_encoded_correctly() {
    let (mut read_side, write_side) = duplex(16 * 1024);
    // Writer and checker use the same key/IV so their CTR keystreams line up.
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    // Boundary where abridged switches to extended length encoding.
    let payload = vec![0x5Au8; 0x7f * 4];
    write_client_payload(
        &mut writer,
        ProtoTag::Abridged,
        RPC_FLAG_QUICKACK,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect("extended abridged payload should serialize");
    writer.flush().await.expect("flush must succeed");
    // Extended frame: 4 header bytes (marker + 3-byte length) plus the payload.
    let mut encrypted = vec![0u8; 4 + payload.len()];
    read_side
        .read_exact(&mut encrypted)
        .await
        .expect("must read serialized extended abridged frame");
    let plaintext = decryptor.decrypt(&encrypted);
    assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set");
    // Word count 0x7f encoded little-endian across three bytes.
    assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]);
    assert_eq!(&plaintext[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() {
    // Abridged framing counts length in 4-byte words, so a 3-byte body
    // violates the alignment contract and must be refused.
    let (_read_side, write_side) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut crypto_writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    let misaligned_body = [1u8, 2, 3];
    let err = write_client_payload(
        &mut crypto_writer,
        ProtoTag::Abridged,
        0,
        &misaligned_body,
        &rng,
        &mut scratch,
    )
    .await
    .expect_err("misaligned abridged payload must be rejected");
    let msg = err.to_string();
    assert!(
        msg.contains("4-byte aligned"),
        "error should explain alignment contract, got: {msg}"
    );
}
#[tokio::test]
async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() {
    // Secure framing requires 4-byte alignment; a 5-byte body must be refused
    // before anything is written to the wire.
    let (_read_side, write_side) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut crypto_writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    let misaligned_body = [9u8, 8, 7, 6, 5];
    let err = write_client_payload(
        &mut crypto_writer,
        ProtoTag::Secure,
        0,
        &misaligned_body,
        &rng,
        &mut scratch,
    )
    .await
    .expect_err("misaligned secure payload must be rejected");
    let msg = err.to_string();
    assert!(
        msg.contains("Secure payload must be 4-byte aligned"),
        "error should be explicit for fail-closed triage, got: {msg}"
    );
}
#[tokio::test]
async fn write_client_payload_intermediate_quickack_sets_length_msb() {
    // Intermediate framing: 4-byte little-endian length prefix; quickack is
    // signalled through the length word's most significant bit.
    let (mut read_side, write_side) = duplex(4096);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut tx = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let mut rx_cipher = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    let payload = b"hello-middle-relay";
    write_client_payload(
        &mut tx,
        ProtoTag::Intermediate,
        RPC_FLAG_QUICKACK,
        payload,
        &rng,
        &mut scratch,
    )
    .await
    .expect("intermediate quickack payload should serialize");
    tx.flush().await.expect("flush must succeed");
    // Pull the whole frame (header + body) off the wire and decrypt it.
    let mut encrypted = vec![0u8; 4 + payload.len()];
    read_side
        .read_exact(&mut encrypted)
        .await
        .expect("must read intermediate frame");
    let plaintext = rx_cipher.decrypt(&encrypted);
    let mut header = [0u8; 4];
    header.copy_from_slice(&plaintext[..4]);
    let len_with_flags = u32::from_le_bytes(header);
    assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set");
    assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len());
    assert_eq!(&plaintext[4..], payload);
}
/// Secure framing: 4-byte little-endian length prefix with the quickack MSB,
/// plus randomized tail padding. The padding amount is random per frame, so
/// the test checks bounds (1..=3 extra bytes) rather than an exact length.
#[tokio::test]
async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() {
    let (mut read_side, write_side) = duplex(4096);
    // Writer and checker use the same key/IV so their CTR keystreams line up.
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode.
    write_client_payload(
        &mut writer,
        ProtoTag::Secure,
        RPC_FLAG_QUICKACK,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect("secure quickack payload should serialize");
    writer.flush().await.expect("flush must succeed");
    // Secure mode adds 1..=3 bytes of randomized tail padding.
    let mut encrypted_header = [0u8; 4];
    read_side
        .read_exact(&mut encrypted_header)
        .await
        .expect("must read secure header");
    // NOTE: header and body are decrypted with the same cipher instance in
    // wire order, keeping the CTR keystream position in sync.
    let decrypted_header = decryptor.decrypt(&encrypted_header);
    let header: [u8; 4] = decrypted_header
        .try_into()
        .expect("decrypted secure header must be 4 bytes");
    let wire_len_raw = u32::from_le_bytes(header);
    assert_ne!(
        wire_len_raw & 0x8000_0000,
        0,
        "secure quickack bit must be set"
    );
    // Low 31 bits carry the body length (payload + random padding).
    let wire_len = (wire_len_raw & 0x7fff_ffff) as usize;
    assert!(wire_len >= payload.len());
    let padding_len = wire_len - payload.len();
    assert!(
        (1..=3).contains(&padding_len),
        "secure writer must add bounded random tail padding, got {padding_len}"
    );
    let mut encrypted_body = vec![0u8; wire_len];
    read_side
        .read_exact(&mut encrypted_body)
        .await
        .expect("must read secure body");
    let decrypted_body = decryptor.decrypt(&encrypted_body);
    // Payload prefix must survive intact; the padding tail is unchecked.
    assert_eq!(&decrypted_body[..payload.len()], payload.as_slice());
}
#[tokio::test]
#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"]
async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() {
    // Abridged framing encodes length as a 24-bit word count; a payload of
    // 2^24 words (one word past the encodable maximum) must be rejected with
    // an explicit error rather than being truncated or wrapped.
    let (_read_side, write_side) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    // Exactly one 4-byte word above the encodable 24-bit abridged length range.
    let payload = vec![0x00u8; (1 << 24) * 4];
    let err = write_client_payload(
        &mut writer,
        ProtoTag::Abridged,
        0,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect_err("oversized abridged payload must be rejected");
    let msg = format!("{err}");
    assert!(
        msg.contains("Abridged frame too large"),
        "error must clearly indicate oversize fail-close path, got: {msg}"
    );
}
#[tokio::test]
async fn write_client_ack_intermediate_is_little_endian() {
    // Intermediate-proto acks must be serialized as a little-endian u32.
    let (mut rx_half, tx_half) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(tx_half, AesCtr::new(&key, iv), 8 * 1024);
    // Same key/iv so reads decrypt in lockstep with the writer's keystream.
    let mut stream_decryptor = AesCtr::new(&key, iv);
    write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44)
        .await
        .expect("ack serialization should succeed");
    writer.flush().await.expect("flush must succeed");
    let mut wire_bytes = [0u8; 4];
    rx_half
        .read_exact(&mut wire_bytes)
        .await
        .expect("must read ack bytes");
    let decrypted = stream_decryptor.decrypt(&wire_bytes);
    assert_eq!(decrypted.as_slice(), &0x11_22_33_44u32.to_le_bytes());
}
#[tokio::test]
async fn write_client_ack_abridged_is_big_endian() {
    // Abridged-proto acks must be serialized as a big-endian u32.
    let (mut rx_half, tx_half) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(tx_half, AesCtr::new(&key, iv), 8 * 1024);
    // Same key/iv so reads decrypt in lockstep with the writer's keystream.
    let mut stream_decryptor = AesCtr::new(&key, iv);
    write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF)
        .await
        .expect("ack serialization should succeed");
    writer.flush().await.expect("flush must succeed");
    let mut wire_bytes = [0u8; 4];
    rx_half
        .read_exact(&mut wire_bytes)
        .await
        .expect("must read ack bytes");
    let decrypted = stream_decryptor.decrypt(&wire_bytes);
    assert_eq!(decrypted.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes());
}
#[tokio::test]
async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() {
    // 0x7e words is the largest payload still using the 1-byte abridged header.
    let (mut rx_half, tx_half) = duplex(1024 * 1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(tx_half, AesCtr::new(&key, iv), 8 * 1024);
    let mut stream_decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    let payload = vec![0xABu8; 0x7e * 4];
    write_client_payload(&mut writer, ProtoTag::Abridged, 0, &payload, &rng, &mut scratch)
        .await
        .expect("boundary payload should serialize");
    writer.flush().await.expect("flush must succeed");
    let mut wire_bytes = vec![0u8; 1 + payload.len()];
    rx_half.read_exact(&mut wire_bytes).await.unwrap();
    let decrypted = stream_decryptor.decrypt(&wire_bytes);
    // Single header byte carries the word count directly.
    assert_eq!(decrypted[0], 0x7e);
    assert_eq!(&decrypted[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() {
    // At 0x80 words the abridged framing switches to the extended form:
    // marker byte 0x7f followed by a 3-byte little-endian word count. With no
    // flags set the prefix must not carry any stray bits.
    let (mut read_side, write_side) = duplex(16 * 1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    // Mirrored AES-CTR (same key/iv) decrypts exactly what the writer encrypted.
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    let payload = vec![0x42u8; 0x80 * 4];
    write_client_payload(
        &mut writer,
        ProtoTag::Abridged,
        0,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect("extended payload should serialize");
    writer.flush().await.expect("flush must succeed");
    let mut encrypted = vec![0u8; 4 + payload.len()];
    read_side.read_exact(&mut encrypted).await.unwrap();
    let plain = decryptor.decrypt(&encrypted);
    // 0x7f marker, then the word count 0x80 in three little-endian bytes.
    assert_eq!(plain[0], 0x7f);
    assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]);
    assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_intermediate_zero_length_emits_header_only() {
    // An empty intermediate payload must still emit a 4-byte all-zero header.
    let (mut rx_half, tx_half) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(tx_half, AesCtr::new(&key, iv), 8 * 1024);
    let mut stream_decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    write_client_payload(&mut writer, ProtoTag::Intermediate, 0, &[], &rng, &mut scratch)
        .await
        .expect("zero-length intermediate payload should serialize");
    writer.flush().await.expect("flush must succeed");
    let mut wire_bytes = [0u8; 4];
    rx_half.read_exact(&mut wire_bytes).await.unwrap();
    let decrypted = stream_decryptor.decrypt(&wire_bytes);
    assert_eq!(decrypted.as_slice(), &[0, 0, 0, 0]);
}
#[tokio::test]
async fn write_client_payload_intermediate_ignores_unrelated_flags() {
    // A flag bit other than RPC_FLAG_QUICKACK (here 0x4000_0000) must not leak
    // into the intermediate length header: only the quickack MSB may alter it.
    let (mut read_side, write_side) = duplex(1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    // Mirrored AES-CTR (same key/iv) decrypts exactly what the writer encrypted.
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    let payload = [7u8; 12];
    write_client_payload(
        &mut writer,
        ProtoTag::Intermediate,
        0x4000_0000,
        &payload,
        &rng,
        &mut frame_buf,
    )
    .await
    .expect("payload should serialize");
    writer.flush().await.expect("flush must succeed");
    // 4-byte header + 12-byte payload.
    let mut encrypted = [0u8; 16];
    read_side.read_exact(&mut encrypted).await.unwrap();
    let plain = decryptor.decrypt(&encrypted);
    let len = u32::from_le_bytes(plain[0..4].try_into().unwrap());
    assert_eq!(len, payload.len() as u32, "only quickack bit may affect header");
    assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_secure_without_quickack_keeps_msb_clear() {
    // Without RPC_FLAG_QUICKACK the secure header's top bit must remain zero.
    let (mut rx_half, tx_half) = duplex(4096);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(tx_half, AesCtr::new(&key, iv), 8 * 1024);
    let mut stream_decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut scratch = Vec::new();
    let payload = [0x1Du8; 64];
    write_client_payload(&mut writer, ProtoTag::Secure, 0, &payload, &rng, &mut scratch)
        .await
        .expect("payload should serialize");
    writer.flush().await.expect("flush must succeed");
    let mut wire_header = [0u8; 4];
    rx_half.read_exact(&mut wire_header).await.unwrap();
    let decrypted = stream_decryptor.decrypt(&wire_header);
    let header_bytes: [u8; 4] = decrypted.as_slice().try_into().unwrap();
    assert_eq!(
        u32::from_le_bytes(header_bytes) & 0x8000_0000,
        0,
        "quickack bit must stay clear"
    );
}
#[tokio::test]
async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() {
    // Light statistical check: across 96 secure frames, the randomized tail
    // padding (1..=3 bytes) must hit at least two distinct lengths, i.e. the
    // padding generator must not be degenerate.
    let (mut read_side, write_side) = duplex(256 * 1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    // Mirrored AES-CTR (same key/iv) decrypts exactly what the writer encrypted.
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    let payload = [0x55u8; 100];
    // Index = observed padding length (1..=3); index 0 intentionally unused.
    let mut seen = [false; 4];
    for _ in 0..96 {
        write_client_payload(
            &mut writer,
            ProtoTag::Secure,
            0,
            &payload,
            &rng,
            &mut frame_buf,
        )
        .await
        .expect("secure payload should serialize");
        writer.flush().await.expect("flush must succeed");
        let mut encrypted_header = [0u8; 4];
        read_side.read_exact(&mut encrypted_header).await.unwrap();
        let plain_header = decryptor.decrypt(&encrypted_header);
        let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
        let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize;
        let padding_len = wire_len - payload.len();
        assert!((1..=3).contains(&padding_len));
        seen[padding_len] = true;
        // Drain the body through the decryptor so the next header decrypts
        // correctly (result intentionally discarded).
        let mut encrypted_body = vec![0u8; wire_len];
        read_side.read_exact(&mut encrypted_body).await.unwrap();
        let _ = decryptor.decrypt(&encrypted_body);
    }
    let distinct = (1..=3).filter(|idx| seen[*idx]).count();
    assert!(
        distinct >= 2,
        "padding generator should not collapse to a single outcome under campaign"
    );
}
#[tokio::test]
async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() {
    // Three back-to-back frames with different proto tags over one encrypted
    // stream; each must be parseable in order without desyncing the reader.
    let (mut read_side, write_side) = duplex(128 * 1024);
    let key = [0u8; 32];
    let iv = 0u128;
    let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
    // Single mirrored AES-CTR decrypts all three frames sequentially.
    let mut decryptor = AesCtr::new(&key, iv);
    let rng = SecureRandom::new();
    let mut frame_buf = Vec::new();
    let p1 = vec![1u8; 8];
    let p2 = vec![2u8; 16];
    let p3 = vec![3u8; 20];
    write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf)
        .await
        .unwrap();
    write_client_payload(
        &mut writer,
        ProtoTag::Intermediate,
        RPC_FLAG_QUICKACK,
        &p2,
        &rng,
        &mut frame_buf,
    )
    .await
    .unwrap();
    write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf)
        .await
        .unwrap();
    writer.flush().await.unwrap();
    // Frame 1: abridged short — 1-byte word-count header.
    let mut e1 = vec![0u8; 1 + p1.len()];
    read_side.read_exact(&mut e1).await.unwrap();
    let d1 = decryptor.decrypt(&e1);
    assert_eq!(d1[0], (p1.len() / 4) as u8);
    assert_eq!(&d1[1..], p1.as_slice());
    // Frame 2: intermediate with quickack — MSB set, low 31 bits = length.
    let mut e2 = vec![0u8; 4 + p2.len()];
    read_side.read_exact(&mut e2).await.unwrap();
    let d2 = decryptor.decrypt(&e2);
    let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap());
    assert_ne!(l2 & 0x8000_0000, 0);
    assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len());
    assert_eq!(&d2[4..], p2.as_slice());
    // Frame 3: secure with bounded tail — payload plus 1..=3 bytes of padding.
    let mut e3h = [0u8; 4];
    read_side.read_exact(&mut e3h).await.unwrap();
    let d3h = decryptor.decrypt(&e3h);
    let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
    assert!(l3 >= p3.len());
    assert!((1..=3).contains(&(l3 - p3.len())));
    let mut e3b = vec![0u8; l3];
    read_side.read_exact(&mut e3b).await.unwrap();
    let d3b = decryptor.decrypt(&e3b);
    assert_eq!(&d3b[..p3.len()], p3.as_slice());
}
#[test]
fn should_yield_sender_boundary_matrix_blackhat() {
    // Yield requires BOTH backlog pressure and having spent the full fairness
    // budget; exercise every boundary combination as a case table.
    let cases = [
        (0, false, false),
        (0, true, false),
        (C2ME_SENDER_FAIRNESS_BUDGET - 1, true, false),
        (C2ME_SENDER_FAIRNESS_BUDGET, false, false),
        (C2ME_SENDER_FAIRNESS_BUDGET, true, true),
        (C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), true, true),
    ];
    for (sent, backlog, expected) in cases {
        assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
    }
}
#[test]
fn should_yield_sender_light_fuzz_matches_oracle() {
    // xorshift-style PRNG drives 5000 pseudo-random cases compared against the
    // trivial oracle: yield iff backlog AND sent >= fairness budget.
    let mut state: u64 = 0xD00D_BAAD_F00D_CAFE;
    for _ in 0..5000 {
        state ^= state << 7;
        state ^= state >> 9;
        state ^= state << 8;
        let sent_octets = (state as usize) & 0x1fff;
        let has_backlog = state & 1 == 1;
        let oracle = has_backlog && sent_octets >= C2ME_SENDER_FAIRNESS_BUDGET;
        assert_eq!(should_yield_c2me_sender(sent_octets, has_backlog), oracle);
    }
}
#[test]
fn quota_would_be_exceeded_exact_remaining_one_byte() {
    // With 99 of 100 bytes used, one more byte fits exactly; two must trip.
    let stats = Stats::new();
    let user = "quota-edge";
    let limit = Some(100u64);
    stats.add_user_octets_to(user, 99);
    let fits_exactly = quota_would_be_exceeded_for_user(&stats, user, limit, 1);
    assert!(!fits_exactly, "exactly remaining budget should be allowed");
    let over_by_one = quota_would_be_exceeded_for_user(&stats, user, limit, 2);
    assert!(over_by_one, "one byte beyond remaining budget must be rejected");
}
#[test]
fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
    // Accounting near u64::MAX must not wrap into a false "allowed" verdict.
    let stats = Stats::new();
    let user = "quota-saturating-edge";
    stats.add_user_octets_to(user, u64::MAX - 4);
    let rejected = quota_would_be_exceeded_for_user(&stats, user, Some(u64::MAX - 3), 2);
    assert!(rejected, "saturating arithmetic edge must stay fail-closed");
}
#[test]
fn quota_exceeded_boundary_is_inclusive() {
    // Usage equal to the quota counts as exceeded; one byte of headroom does not.
    let stats = Stats::new();
    let user = "quota-inclusive-boundary";
    stats.add_user_octets_to(user, 50);
    let at_limit = quota_exceeded_for_user(&stats, user, Some(50));
    let below_limit = quota_exceeded_for_user(&stats, user, Some(51));
    assert!(at_limit);
    assert!(!below_limit);
}
#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
    // With channel capacity available, Close enqueues immediately and arrives
    // at the consumer well within the timeout.
    let (sender, mut receiver) = mpsc::channel::<C2MeCommand>(4);
    enqueue_c2me_command(&sender, C2MeCommand::Close)
        .await
        .expect("close should enqueue on fast path");
    let received = timeout(TokioDuration::from_millis(50), receiver.recv())
        .await
        .expect("must receive close command")
        .expect("close command should be present");
    assert!(matches!(received, C2MeCommand::Close));
}
#[tokio::test]
async fn enqueue_c2me_data_full_then_drain_preserves_order() {
    // Capacity-1 channel is pre-filled so the producer's second enqueue hits
    // backpressure; once the consumer drains the first item, the blocked
    // producer must complete and FIFO order must be preserved.
    let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
    // Pre-fill the single slot directly so the channel is full.
    tx.send(C2MeCommand::Data {
        payload: make_pooled_payload(&[1]),
        flags: 10,
    })
    .await
    .unwrap();
    let tx2 = tx.clone();
    // This enqueue cannot complete until the consumer frees the slot.
    let producer = tokio::spawn(async move {
        enqueue_c2me_command(
            &tx2,
            C2MeCommand::Data {
                payload: make_pooled_payload(&[2, 2]),
                flags: 20,
            },
        )
        .await
    });
    // Give the producer a chance to park on the full channel before draining.
    tokio::time::sleep(TokioDuration::from_millis(10)).await;
    let first = rx.recv().await.expect("first item should exist");
    match first {
        C2MeCommand::Data { payload, flags } => {
            assert_eq!(payload.as_ref(), &[1]);
            assert_eq!(flags, 10);
        }
        C2MeCommand::Close => panic!("unexpected close as first item"),
    }
    producer.await.unwrap().expect("producer should complete");
    let second = timeout(TokioDuration::from_millis(100), rx.recv())
        .await
        .unwrap()
        .expect("second item should exist");
    match second {
        C2MeCommand::Data { payload, flags } => {
            assert_eq!(payload.as_ref(), &[2, 2]);
            assert_eq!(flags, 20);
        }
        C2MeCommand::Close => panic!("unexpected close as second item"),
    }
}

View File

@ -47,11 +47,7 @@ fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() {
); );
} }
assert_eq!( assert_eq!(dedup.len(), 2, "bypass path must not mutate dedup cardinality");
dedup.len(),
2,
"bypass path must not mutate dedup cardinality"
);
assert_eq!( assert_eq!(
*dedup *dedup
.get(&0xAAAABBBBCCCCDDDD) .get(&0xAAAABBBBCCCCDDDD)
@ -77,11 +73,7 @@ fn edge_all_full_burst_does_not_poison_later_false_path_tracking() {
let now = Instant::now(); let now = Instant::now();
for i in 0..8192u64 { for i in 0..8192u64 {
assert!(should_emit_full_desync( assert!(should_emit_full_desync(0xABCD_0000_0000_0000 ^ i, true, now));
0xABCD_0000_0000_0000 ^ i,
true,
now
));
} }
let tracked_key = 0xDEAD_BEEF_0000_0001u64; let tracked_key = 0xDEAD_BEEF_0000_0001u64;
@ -183,9 +175,5 @@ fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() {
} }
assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096);
assert_eq!( assert_eq!(dedup.len(), before_len, "parallel all_full storm must not mutate cache len");
dedup.len(),
before_len,
"parallel all_full storm must not mutate cache len"
);
} }

View File

@ -2,8 +2,8 @@ use super::*;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader}; use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::AtomicU64;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tokio::io::duplex; use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@ -93,9 +93,7 @@ async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
.await .await
.expect("idle test must complete"); .expect("idle test must complete");
assert!( assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut));
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
let err_text = match result { let err_text = match result {
Err(ProxyError::Io(ref e)) => e.to_string(), Err(ProxyError::Io(ref e)) => e.to_string(),
_ => String::new(), _ => String::new(),
@ -145,9 +143,7 @@ async fn idle_policy_downstream_activity_grace_extends_hard_deadline() {
.await .await
.expect("grace test must complete"); .expect("grace test must complete");
assert!( assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut));
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
assert!( assert!(
start.elapsed() >= TokioDuration::from_millis(100), start.elapsed() >= TokioDuration::from_millis(100),
"recent downstream activity must extend hard idle deadline" "recent downstream activity must extend hard idle deadline"
@ -175,9 +171,7 @@ async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() {
) )
.await; .await;
assert!( assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut));
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)
);
let err_text = match result { let err_text = match result {
Err(ProxyError::Io(ref e)) => e.to_string(), Err(ProxyError::Io(ref e)) => e.to_string(),
_ => String::new(), _ => String::new(),
@ -231,13 +225,8 @@ async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() {
.await .await
.expect("partial frame trickle test must complete"); .expect("partial frame trickle test must complete");
assert!( assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut));
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut) assert_eq!(frame_counter, 0, "partial trickle must not count as a valid frame");
);
assert_eq!(
frame_counter, 0,
"partial trickle must not count as a valid frame"
);
} }
#[tokio::test] #[tokio::test]
@ -302,10 +291,7 @@ async fn protocol_desync_small_frame_updates_reason_counter() {
plaintext.extend_from_slice(&3u32.to_le_bytes()); plaintext.extend_from_slice(&3u32.to_le_bytes());
plaintext.extend_from_slice(&[1u8, 2, 3]); plaintext.extend_from_slice(&[1u8, 2, 3]);
let encrypted = encrypt_for_reader(&plaintext); let encrypted = encrypt_for_reader(&plaintext);
writer writer.write_all(&encrypted).await.expect("must write frame");
.write_all(&encrypted)
.await
.expect("must write frame");
let result = read_client_payload( let result = read_client_payload(
&mut crypto_reader, &mut crypto_reader,
@ -671,8 +657,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 4)] #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() {
{
let _guard = acquire_idle_pressure_test_lock(); let _guard = acquire_idle_pressure_test_lock();
clear_relay_idle_pressure_state_for_testing(); clear_relay_idle_pressure_state_for_testing();
@ -695,8 +680,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
let conn_id = *conn_id; let conn_id = *conn_id;
let stats = stats.clone(); let stats = stats.clone();
joins.push(tokio::spawn(async move { joins.push(tokio::spawn(async move {
let evicted = let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
(idx, conn_id, seen, evicted) (idx, conn_id, seen, evicted)
})); }));
} }
@ -769,8 +753,7 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
let conn_id = *conn_id; let conn_id = *conn_id;
let stats = stats.clone(); let stats = stats.clone();
joins.push(tokio::spawn(async move { joins.push(tokio::spawn(async move {
let evicted = let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
(idx, conn_id, seen, evicted) (idx, conn_id, seen, evicted)
})); }));
} }

View File

@ -1,75 +0,0 @@
use super::*;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
#[test]
fn intermediate_secure_wire_len_allows_max_31bit_payload() {
let (len_val, total) = compute_intermediate_secure_wire_len(0x7fff_fffe, 1, true)
.expect("31-bit wire length should be accepted");
assert_eq!(len_val, 0xffff_ffff, "quickack must use top bit only");
assert_eq!(total, 0x8000_0003);
}
#[test]
fn intermediate_secure_wire_len_rejects_length_above_31bit_limit() {
let err = compute_intermediate_secure_wire_len(0x7fff_ffff, 1, false)
.expect_err("wire length above 31-bit must fail closed");
assert!(
format!("{err}").contains("frame too large"),
"error should identify oversize frame path"
);
}
#[test]
fn intermediate_secure_wire_len_rejects_addition_overflow() {
let err = compute_intermediate_secure_wire_len(usize::MAX, 1, false)
.expect_err("overflowing addition must fail closed");
assert!(
format!("{err}").contains("overflow"),
"error should clearly report overflow"
);
}
#[test]
fn desync_forensics_len_bytes_marks_truncation_for_oversize_values() {
let (small_bytes, small_truncated) = desync_forensics_len_bytes(0x1020_3040);
assert_eq!(small_bytes, 0x1020_3040u32.to_le_bytes());
assert!(!small_truncated);
let (huge_bytes, huge_truncated) = desync_forensics_len_bytes(usize::MAX);
assert_eq!(huge_bytes, u32::MAX.to_le_bytes());
assert!(huge_truncated);
}
#[test]
fn report_desync_frame_too_large_preserves_full_length_in_error_message() {
let state = RelayForensicsState {
trace_id: 0x1234,
conn_id: 0x5678,
user: "middle-desync-oversize".to_string(),
peer: "198.51.100.55:443".parse().expect("valid test peer"),
peer_hash: 0xAABBCCDD,
started_at: Instant::now(),
bytes_c2me: 7,
bytes_me2c: Arc::new(AtomicU64::new(9)),
desync_all_full: false,
};
let huge_len = usize::MAX;
let err = report_desync_frame_too_large(
&state,
ProtoTag::Intermediate,
3,
1024,
huge_len,
None,
&Stats::new(),
);
let msg = format!("{err}");
assert!(
msg.contains(&huge_len.to_string()),
"error must preserve full usize length for forensics"
);
}

View File

@ -1,131 +0,0 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
#[test]
fn saturation_uses_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
let user = format!("middle-quota-overflow-{}", std::process::id());
let first = quota_user_lock(&user);
let second = quota_user_lock(&user);
assert!(
Arc::ptr_eq(&first, &second),
"overflow user must get deterministic same lock while cache is saturated"
);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock map"
);
assert!(
map.get(&user).is_none(),
"overflow user should stay outside bounded lock map under saturation"
);
drop(retained);
}
#[test]
fn overflow_striping_keeps_different_users_distributed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-dist-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
let a = quota_user_lock("middle-overflow-user-a");
let b = quota_user_lock("middle-overflow-user-b");
let c = quota_user_lock("middle-overflow-user-c");
let distinct = [
Arc::as_ptr(&a) as usize,
Arc::as_ptr(&b) as usize,
Arc::as_ptr(&c) as usize,
]
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
assert!(
distinct >= 2,
"striped overflow lock set should avoid collapsing all users to one lock"
);
drop(retained);
}
#[test]
fn reclaim_path_caches_new_user_after_stale_entries_drop() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-reclaim-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let user = format!("middle-quota-reclaim-user-{}", std::process::id());
let got = quota_user_lock(&user);
assert!(map.get(&user).is_some());
assert!(
Arc::strong_count(&got) >= 2,
"after reclaim, lock should be held both by caller and map"
);
}
#[test]
fn overflow_path_same_user_is_stable_across_parallel_threads() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-quota-thread-held-{}-{idx}",
std::process::id()
)));
}
let user = format!("middle-quota-overflow-thread-user-{}", std::process::id());
let mut workers = Vec::new();
for _ in 0..32 {
let user = user.clone();
workers.push(std::thread::spawn(move || quota_user_lock(&user)));
}
let first = workers
.remove(0)
.join()
.expect("thread must return lock handle");
for worker in workers {
let got = worker.join().expect("thread must return lock handle");
assert!(
Arc::ptr_eq(&first, &got),
"same overflow user should resolve to one striped lock even under contention"
);
}
drop(retained);
}

Some files were not shown because too many files have changed in this diff Show More