mirror of https://github.com/telemt/telemt.git
Compare commits
14 Commits
6fc188f0c4
...
0475844701
| Author | SHA1 | Date |
|---|---|---|
|
|
0475844701 | |
|
|
1abf9bd05c | |
|
|
6f17d4d231 | |
|
|
bf30e93284 | |
|
|
91be148b72 | |
|
|
d4cda6d546 | |
|
|
e35d69c61f | |
|
|
a353a94175 | |
|
|
b856250b2c | |
|
|
97d1476ded | |
|
|
cde14fc1bf | |
|
|
5723d50d0b | |
|
|
3eb384e02a | |
|
|
c960e0e245 |
|
|
@ -11,7 +11,7 @@ env:
|
|||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
name: Compile, Test, Lint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
|
|
@ -39,23 +39,11 @@ jobs:
|
|||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Build Release
|
||||
run: cargo build --release --verbose
|
||||
- name: Compile (no tests)
|
||||
run: cargo check --workspace --all-features --lib --bins --verbose
|
||||
|
||||
- name: Run tests
|
||||
run: cargo test --verbose
|
||||
|
||||
- name: Stress quota-lock suites (PR only)
|
||||
if: github.event_name == 'pull_request'
|
||||
env:
|
||||
RUST_TEST_THREADS: 16
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in $(seq 1 12); do
|
||||
echo "[quota-lock-stress] iteration ${i}/12"
|
||||
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
|
||||
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
|
||||
done
|
||||
- name: Run tests (single pass)
|
||||
run: cargo test --workspace --all-features --verbose
|
||||
|
||||
# clippy dont fail on warnings because of active development of telemt
|
||||
# and many warnings
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
name: Stress Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 2 * * *'
|
||||
pull_request:
|
||||
branches: ["*"]
|
||||
paths:
|
||||
- src/proxy/**
|
||||
- src/transport/**
|
||||
- src/stream/**
|
||||
- src/protocol/**
|
||||
- src/tls_front/**
|
||||
- Cargo.toml
|
||||
- Cargo.lock
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
quota-lock-stress:
|
||||
name: Quota-lock stress loop
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install latest stable Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo registry and build artifacts
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-stress-
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Run quota-lock stress suites
|
||||
env:
|
||||
RUST_TEST_THREADS: 16
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in $(seq 1 12); do
|
||||
echo "[quota-lock-stress] iteration ${i}/12"
|
||||
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
|
||||
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
|
||||
done
|
||||
|
|
@ -1454,9 +1454,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "iri-string"
|
||||
version = "0.7.10"
|
||||
version = "0.7.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
|
||||
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
|
|
@ -1486,7 +1486,7 @@ dependencies = [
|
|||
"cesu8",
|
||||
"cfg-if",
|
||||
"combine",
|
||||
"jni-sys",
|
||||
"jni-sys 0.3.1",
|
||||
"log",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
|
|
@ -1495,9 +1495,31 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
|
||||
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
|
||||
dependencies = [
|
||||
"jni-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
|
||||
dependencies = [
|
||||
"jni-sys-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jni-sys-macros"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
|
|
@ -1659,9 +1681,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "moka"
|
||||
version = "0.12.14"
|
||||
version = "0.12.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
|
||||
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"crossbeam-epoch",
|
||||
|
|
@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
|
|||
|
||||
[[package]]
|
||||
name = "telemt"
|
||||
version = "3.3.29"
|
||||
version = "3.3.30"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"anyhow",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
[package]
|
||||
name = "telemt"
|
||||
version = "3.3.29"
|
||||
version = "3.3.30"
|
||||
edition = "2024"
|
||||
|
||||
[features]
|
||||
redteam_offline_expected_fail = []
|
||||
|
||||
[dependencies]
|
||||
# C
|
||||
libc = "0.2"
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ This document lists all configuration keys accepted by `config.toml`.
|
|||
| Parameter | Type | Default | Constraints / validation | Description |
|
||||
|---|---|---|---|---|
|
||||
| data_path | `String \| null` | `null` | — | Optional runtime data directory path. |
|
||||
| prefer_ipv6 | `bool` | `false` | — | Prefer IPv6 where applicable in runtime logic. |
|
||||
| prefer_ipv6 | `bool` | `false` | Deprecated. Use `network.prefer`. | Deprecated legacy IPv6 preference flag migrated to `network.prefer`. |
|
||||
| fast_mode | `bool` | `true` | — | Enables fast-path optimizations for traffic processing. |
|
||||
| use_middle_proxy | `bool` | `true` | none | Enables ME transport mode; if `false`, runtime falls back to direct DC routing. |
|
||||
| proxy_secret_path | `String \| null` | `"proxy-secret"` | Path may be `null`. | Path to Telegram infrastructure proxy-secret file used by ME handshake logic. |
|
||||
|
|
@ -44,6 +44,7 @@ This document lists all configuration keys accepted by `config.toml`.
|
|||
| me_writer_cmd_channel_capacity | `usize` | `4096` | Must be `> 0`. | Capacity of per-writer command channel. |
|
||||
| me_route_channel_capacity | `usize` | `768` | Must be `> 0`. | Capacity of per-connection ME response route channel. |
|
||||
| me_c2me_channel_capacity | `usize` | `1024` | Must be `> 0`. | Capacity of per-client command queue (client reader -> ME sender). |
|
||||
| me_c2me_send_timeout_ms | `u64` | `4000` | `0..=60000`. | Maximum wait for enqueueing client->ME commands when the per-client queue is full (`0` keeps legacy unbounded wait). |
|
||||
| me_reader_route_data_wait_ms | `u64` | `2` | `0..=20`. | Bounded wait for routing ME DATA to per-connection queue (`0` = no wait). |
|
||||
| me_d2c_flush_batch_max_frames | `usize` | `32` | `1..=512`. | Max ME->client frames coalesced before flush. |
|
||||
| me_d2c_flush_batch_max_bytes | `usize` | `131072` | `4096..=2_097_152`. | Max ME->client payload bytes coalesced before flush. |
|
||||
|
|
@ -105,6 +106,8 @@ This document lists all configuration keys accepted by `config.toml`.
|
|||
| me_warn_rate_limit_ms | `u64` | `5000` | Must be `> 0`. | Cooldown for repetitive ME warning logs (ms). |
|
||||
| me_route_no_writer_mode | `"async_recovery_failfast" \| "inline_recovery_legacy" \| "hybrid_async_persistent"` | `"hybrid_async_persistent"` | — | Route behavior when no writer is immediately available. |
|
||||
| me_route_no_writer_wait_ms | `u64` | `250` | `10..=5000`. | Max wait in async-recovery failfast mode (ms). |
|
||||
| me_route_hybrid_max_wait_ms | `u64` | `3000` | `50..=60000`. | Maximum cumulative wait in hybrid no-writer mode before failfast fallback (ms). |
|
||||
| me_route_blocking_send_timeout_ms | `u64` | `250` | `0..=5000`. | Maximum wait for blocking route-channel send fallback (`0` keeps legacy unbounded wait). |
|
||||
| me_route_inline_recovery_attempts | `u32` | `3` | Must be `> 0`. | Inline recovery attempts in legacy mode. |
|
||||
| me_route_inline_recovery_wait_ms | `u64` | `3000` | `10..=30000`. | Max inline recovery wait in legacy mode (ms). |
|
||||
| fast_mode_min_tls_record | `usize` | `0` | — | Minimum TLS record size when fast-mode coalescing is enabled (`0` disables). |
|
||||
|
|
@ -124,6 +127,7 @@ This document lists all configuration keys accepted by `config.toml`.
|
|||
| me_secret_atomic_snapshot | `bool` | `true` | — | Keeps selector and secret bytes from the same snapshot atomically. |
|
||||
| proxy_secret_len_max | `usize` | `256` | Must be within `[32, 4096]`. | Upper length limit for accepted proxy-secret bytes. |
|
||||
| me_pool_drain_ttl_secs | `u64` | `90` | none | Time window where stale writers remain fallback-eligible after map change. |
|
||||
| me_instadrain | `bool` | `false` | — | Forces draining stale writers to be removed on the next cleanup tick, bypassing TTL/deadline waiting. |
|
||||
| me_pool_drain_threshold | `u64` | `128` | — | Max draining stale writers before batch force-close (`0` disables threshold cleanup). |
|
||||
| me_pool_drain_soft_evict_enabled | `bool` | `true` | — | Enables gradual soft-eviction of stale writers during drain/reinit instead of immediate hard close. |
|
||||
| me_pool_drain_soft_evict_grace_secs | `u64` | `30` | `0..=3600`. | Grace period before stale writers become soft-evict candidates. |
|
||||
|
|
@ -203,6 +207,7 @@ This document lists all configuration keys accepted by `config.toml`.
|
|||
| metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. |
|
||||
| metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. |
|
||||
| max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). |
|
||||
| accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). |
|
||||
|
||||
Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted.
|
||||
|
||||
|
|
@ -229,7 +234,7 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
|
|||
|---|---|---|---|---|
|
||||
| ip | `IpAddr` | — | — | Listener bind IP. |
|
||||
| announce | `String \| null` | — | — | Public IP/domain announced in proxy links (priority over `announce_ip`). |
|
||||
| announce_ip | `IpAddr \| null` | — | — | Deprecated legacy announce IP (migrated to `announce` if needed). |
|
||||
| announce_ip | `IpAddr \| null` | — | Deprecated. Use `announce`. | Deprecated legacy announce IP (migrated to `announce` if needed). |
|
||||
| proxy_protocol | `bool \| null` | `null` | — | Per-listener override for PROXY protocol enable flag. |
|
||||
| reuse_allow | `bool` | `false` | — | Enables `SO_REUSEPORT` for multi-instance bind sharing. |
|
||||
|
||||
|
|
@ -269,6 +274,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
|
|||
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
|
||||
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
|
||||
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
|
||||
| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. |
|
||||
| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. |
|
||||
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
|
||||
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
|
||||
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
1. Go to @MTProxybot bot.
|
||||
2. Enter the command `/newproxy`
|
||||
3. Send the server IP and port. For example: 1.2.3.4:443
|
||||
4. Open the config `nano /etc/telemt.toml`.
|
||||
4. Open the config `nano /etc/telemt/telemt.toml`.
|
||||
5. Copy and send the user secret from the [access.users] section to the bot.
|
||||
6. Copy the tag received from the bot. For example 1234567890abcdef1234567890abcdef.
|
||||
> [!WARNING]
|
||||
|
|
@ -33,6 +33,9 @@ hello = "ad_tag"
|
|||
hello2 = "ad_tag2"
|
||||
```
|
||||
|
||||
## Why is middle proxy (ME) needed
|
||||
https://github.com/telemt/telemt/discussions/167
|
||||
|
||||
## How many people can use 1 link
|
||||
|
||||
By default, 1 link can be used by any number of people.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
1. Зайти в бота @MTProxybot.
|
||||
2. Ввести команду `/newproxy`
|
||||
3. Отправить IP и порт сервера. Например: 1.2.3.4:443
|
||||
4. Открыть конфиг `nano /etc/telemt.toml`.
|
||||
4. Открыть конфиг `nano /etc/telemt/telemt.toml`.
|
||||
5. Скопировать и отправить боту секрет пользователя из раздела [access.users].
|
||||
6. Скопировать полученный tag у бота. Например 1234567890abcdef1234567890abcdef.
|
||||
> [!WARNING]
|
||||
|
|
@ -33,6 +33,10 @@ hello = "ad_tag"
|
|||
hello2 = "ad_tag2"
|
||||
```
|
||||
|
||||
## Зачем нужен middle proxy (ME)
|
||||
https://github.com/telemt/telemt/discussions/167
|
||||
|
||||
|
||||
## Сколько человек может пользоваться 1 ссылкой
|
||||
|
||||
По умолчанию 1 ссылкой может пользоваться сколько угодно человек.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,283 @@
|
|||
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
|
||||
|
||||
## Concept
|
||||
- **Server A** (__conditionally Russian Federation_):\
|
||||
Entry point, receives Telegram proxy user traffic via **HAProxy** (port `443`)\
|
||||
and sends it to the tunnel to Server **B**.\
|
||||
Internal IP in the tunnel — `10.10.10.2`\
|
||||
Port for HAProxy clients — `443\tcp`
|
||||
- **Server B** (_conditionally Netherlands_):\
|
||||
Exit point, runs **telemt** and accepts client connections through Server **A**.\
|
||||
The server must have unrestricted access to Telegram servers.\
|
||||
Internal IP in the tunnel — `10.10.10.1`\
|
||||
AmneziaWG port — `8443\udp`\
|
||||
Port for telemt clients — `443\tcp`
|
||||
|
||||
---
|
||||
|
||||
## Step 1. Setting up the AmneziaWG tunnel (A <-> B)
|
||||
[AmneziaWG](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module) must be installed on all servers.\
|
||||
All following commands are given for **Ubuntu 24.04**.\
|
||||
For RHEL-based distributions, installation instructions are available at the link above.
|
||||
|
||||
### Installing AmneziaWG (Servers A and B)
|
||||
The following steps must be performed on each server:
|
||||
|
||||
#### 1. Adding the AmneziaWG repository and installing required packages:
|
||||
```bash
|
||||
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
|
||||
sudo add-apt-repository ppa:amnezia/ppa && \
|
||||
sudo apt-get install -y amneziawg
|
||||
```
|
||||
|
||||
#### 2. Generating a unique key pair:
|
||||
```bash
|
||||
cd /etc/amnezia/amneziawg && \
|
||||
awg genkey | tee private.key | awg pubkey > public.key
|
||||
```
|
||||
|
||||
As a result, you will get two files in the `/etc/amnezia/amneziawg` folder:\
|
||||
`private.key` - private, and\
|
||||
`public.key` - public server keys
|
||||
|
||||
#### 3. Configuring network interfaces:
|
||||
Obfuscation parameters `S1`, `S2`, `H1`, `H2`, `H3`, `H4` must be strictly identical on both servers.\
|
||||
Parameters `Jc`, `Jmin` and `Jmax` can differ.\
|
||||
Parameters `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) must be specified on the client side (Server **A**).
|
||||
|
||||
Recommendations for choosing values:
|
||||
|
||||
```text
|
||||
Jc — 1 ≤ Jc ≤ 128; from 4 to 12 inclusive
|
||||
Jmin — Jmax > Jmin < 1280*; recommended 8
|
||||
Jmax — Jmin < Jmax ≤ 1280*; recommended 80
|
||||
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
|
||||
recommended range from 15 to 150 inclusive
|
||||
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
|
||||
recommended range from 15 to 150 inclusive
|
||||
H1/H2/H3/H4 — must be unique and differ from each other;
|
||||
recommended range from 5 to 2147483647 inclusive
|
||||
|
||||
* It is assumed that the Internet connection has an MTU of 1280.
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> It is recommended to use your own, unique values.\
|
||||
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters.
|
||||
|
||||
#### Server B Configuration (Netherlands):
|
||||
|
||||
Create the interface configuration file (`awg0`)
|
||||
```bash
|
||||
nano /etc/amnezia/amneziawg/awg0.conf
|
||||
```
|
||||
|
||||
File content
|
||||
```ini
|
||||
[Interface]
|
||||
Address = 10.10.10.1/24
|
||||
ListenPort = 8443
|
||||
PrivateKey = <PRIVATE_KEY_SERVER_B>
|
||||
SaveConfig = true
|
||||
Jc = 4
|
||||
Jmin = 8
|
||||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
H4 = 432174303
|
||||
|
||||
[Peer]
|
||||
PublicKey = <PUBLIC_KEY_SERVER_A>
|
||||
AllowedIPs = 10.10.10.2/32
|
||||
```
|
||||
`ListenPort` - the port on which the server will wait for connections, you can choose any free one.\
|
||||
`<PRIVATE_KEY_SERVER_B>` - the content of the `private.key` file from Server **B**.\
|
||||
`<PUBLIC_KEY_SERVER_A>` - the content of the `public.key` file from Server **A**.
|
||||
|
||||
Open the port on the firewall (if enabled):
|
||||
```bash
|
||||
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
|
||||
```
|
||||
|
||||
`<PUBLIC_IP_SERVER_A>` - the external IP address of Server **A**.
|
||||
|
||||
#### Server A Configuration (Russian Federation):
|
||||
Create the interface configuration file (awg0)
|
||||
|
||||
```bash
|
||||
nano /etc/amnezia/amneziawg/awg0.conf
|
||||
```
|
||||
|
||||
File content
|
||||
```ini
|
||||
[Interface]
|
||||
Address = 10.10.10.2/24
|
||||
PrivateKey = <PRIVATE_KEY_SERVER_A>
|
||||
Jc = 4
|
||||
Jmin = 8
|
||||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
H4 = 432174303
|
||||
I1 = <b 0xc10000000108981eba846e21f74e00>
|
||||
I2 = <b 0xc20000000108981eba846e21f74e00>
|
||||
I3 = <b 0xc30000000108981eba846e21f74e00>
|
||||
I4 = <b 0x43981eba846e21f74e>
|
||||
I5 = <b 0x43981eba846e21f74e>
|
||||
|
||||
[Peer]
|
||||
PublicKey = <PUBLIC_KEY_SERVER_B>
|
||||
Endpoint = <PUBLIC_IP_SERVER_B>:8443
|
||||
AllowedIPs = 10.10.10.1/32
|
||||
PersistentKeepalive = 25
|
||||
```
|
||||
|
||||
`<PRIVATE_KEY_SERVER_A>` - the content of the `private.key` file from Server **A**.\
|
||||
`<PUBLIC_KEY_SERVER_B>` - the content of the `public.key` file from Server **B**.\
|
||||
`<PUBLIC_IP_SERVER_B>` - the public IP address of Server **B**.
|
||||
|
||||
Enable the tunnel on both servers:
|
||||
```bash
|
||||
sudo systemctl enable --now awg-quick@awg0
|
||||
```
|
||||
|
||||
Make sure Server B is accessible from Server A through the tunnel.
|
||||
```bash
|
||||
ping 10.10.10.1
|
||||
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
|
||||
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
|
||||
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
|
||||
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
|
||||
^C
|
||||
```
|
||||
---
|
||||
|
||||
## Step 2. Installing telemt on Server B (conditionally Netherlands)
|
||||
Installation and configuration are described [here](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) or [here](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
|
||||
It is assumed that telemt expects connections on port `443\tcp`.
|
||||
|
||||
In the telemt config, you must enable the `Proxy` protocol and restrict connections to it only through the tunnel.
|
||||
```toml
|
||||
[server]
|
||||
port = 443
|
||||
listen_addr_ipv4 = "10.10.10.1"
|
||||
proxy_protocol = true
|
||||
```
|
||||
|
||||
Also, for correct link generation, specify the FQDN or IP address and port of Server `A`
|
||||
```toml
|
||||
[general.links]
|
||||
show = "*"
|
||||
public_host = "<FQDN_OR_IP_SERVER_A>"
|
||||
public_port = 443
|
||||
```
|
||||
|
||||
Open the port on the firewall (if enabled):
|
||||
```bash
|
||||
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Step 3. Configuring HAProxy on Server A (Russian Federation)
|
||||
Since the version in the standard Ubuntu repository is relatively old, it makes sense to use the official Docker image.\
|
||||
[Instructions](https://docs.docker.com/engine/install/ubuntu/) for installing Docker on Ubuntu.
|
||||
|
||||
> [!WARNING]
|
||||
> By default, regular users do not have rights to use ports < 1024.
|
||||
> Attempts to run HAProxy on port 443 can lead to errors:
|
||||
> ```
|
||||
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
|
||||
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
|
||||
> ```
|
||||
> There are two simple ways to bypass this restriction, choose one:
|
||||
> 1. At the OS level, change the net.ipv4.ip_unprivileged_port_start setting to allow users to use all ports:
|
||||
> ```
|
||||
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
|
||||
> ```
|
||||
> or
|
||||
>
|
||||
> 2. Run HAProxy as root:
|
||||
> Uncomment the `user: "root"` parameter in docker-compose.yaml.
|
||||
|
||||
#### Create a folder for HAProxy:
|
||||
```bash
|
||||
mkdir -p /opt/docker-compose/haproxy && cd $_
|
||||
```
|
||||
|
||||
#### Create the docker-compose.yaml file
|
||||
`nano docker-compose.yaml`
|
||||
|
||||
File content
|
||||
```yaml
|
||||
services:
|
||||
haproxy:
|
||||
image: haproxy:latest
|
||||
container_name: haproxy
|
||||
restart: unless-stopped
|
||||
# user: "root"
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "1m"
|
||||
max-file: "1"
|
||||
```
|
||||
|
||||
#### Create the haproxy.cfg config file
|
||||
Accept connections on port 443\tcp and send them through the tunnel to Server `B` 10.10.10.1:443
|
||||
|
||||
`nano haproxy.cfg`
|
||||
|
||||
File content
|
||||
|
||||
```haproxy
|
||||
global
|
||||
log stdout format raw local0
|
||||
maxconn 10000
|
||||
|
||||
defaults
|
||||
log global
|
||||
mode tcp
|
||||
option tcplog
|
||||
option clitcpka
|
||||
option srvtcpka
|
||||
timeout connect 5s
|
||||
timeout client 2h
|
||||
timeout server 2h
|
||||
timeout check 5s
|
||||
|
||||
frontend tcp_in_443
|
||||
bind *:443
|
||||
maxconn 8000
|
||||
option tcp-smart-accept
|
||||
default_backend telemt_nodes
|
||||
|
||||
backend telemt_nodes
|
||||
option tcp-smart-connect
|
||||
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
|
||||
|
||||
|
||||
```
|
||||
> [!WARNING]
|
||||
> **The file must end with an empty line, otherwise HAProxy will not start!**
|
||||
|
||||
#### Allow port 443\tcp in the firewall (if enabled)
|
||||
```bash
|
||||
sudo ufw allow 443/tcp
|
||||
```
|
||||
|
||||
#### Start the HAProxy container
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
If everything is configured correctly, you can now try connecting Telegram clients using links from the telemt log\api.
|
||||
|
|
@ -0,0 +1,287 @@
|
|||
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
|
||||
|
||||
## Концепция
|
||||
- **Сервер A** (_РФ_):\
|
||||
Точка входа, принимает трафик пользователей Telegram-прокси через **HAProxy** (порт `443`)\
|
||||
и отправляет в туннель на Сервер **B**.\
|
||||
Внутренний IP в туннеле — `10.10.10.2`\
|
||||
Порт для клиентов HAProxy — `443\tcp`
|
||||
- **Сервер B** (_условно Нидерланды_):\
|
||||
Точка выхода, на нем работает **telemt** и принимает подключения клиентов через Сервер **A**.\
|
||||
На сервере должен быть неограниченный доступ до серверов Telegram.\
|
||||
Внутренний IP в туннеле — `10.10.10.1`\
|
||||
Порт AmneziaWG — `8443\udp`\
|
||||
Порт для клиентов telemt — `443\tcp`
|
||||
|
||||
---
|
||||
|
||||
## Шаг 1. Настройка туннеля AmneziaWG (A <-> B)
|
||||
|
||||
На всех серверах необходимо установить [amneziawg](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module).\
|
||||
Далее все команды даны для **Ununtu 24.04**.\
|
||||
Для RHEL-based дистрибутивов инструкция по установке есть по ссылке выше.
|
||||
|
||||
### Установка AmneziaWG (Сервера A и B)
|
||||
На каждом из серверов необходимо выполнить следующие шаги:
|
||||
|
||||
#### 1. Добавление репозитория AmneziaWG и установка необходимых пакетов:
|
||||
```bash
|
||||
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
|
||||
sudo add-apt-repository ppa:amnezia/ppa && \
|
||||
sudo apt-get install -y amneziawg
|
||||
```
|
||||
|
||||
#### 2. Генерация уникальной пары ключей:
|
||||
```bash
|
||||
cd /etc/amnezia/amneziawg && \
|
||||
awg genkey | tee private.key | awg pubkey > public.key
|
||||
```
|
||||
В результате вы получите в папке `/etc/amnezia/amneziawg` два файла:\
|
||||
`private.key` - приватный и\
|
||||
`public.key` - публичный ключи сервера
|
||||
|
||||
#### 3. Настройка сетевых интерфейсов:
|
||||
|
||||
Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
|
||||
Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\
|
||||
Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**).
|
||||
|
||||
Рекомендации по выбору значений:
|
||||
```text
|
||||
Jc — 1 ≤ Jc ≤ 128; от 4 до 12 включительно
|
||||
Jmin — Jmax > Jmin < 1280*; рекомендовано 8
|
||||
Jmax — Jmin < Jmax ≤ 1280*; рекомендовано 80
|
||||
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
|
||||
рекомендованный диапазон от 15 до 150 включительно
|
||||
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
|
||||
рекомендованный диапазон от 15 до 150 включительно
|
||||
H1/H2/H3/H4 — должны быть уникальны и отличаться друг от друга;
|
||||
рекомендованный диапазон от 5 до 2147483647 включительно
|
||||
|
||||
* Предполагается, что подключение к Интернету имеет MTU 1280.
|
||||
```
|
||||
> [!IMPORTANT]
|
||||
> Рекомендуется использовать собственные, уникальные значения.\
|
||||
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html).
|
||||
|
||||
#### Конфигурация Сервера B (_Нидерланды_):
|
||||
|
||||
Создаем файл конфигурации интерфейса (`awg0`)
|
||||
```bash
|
||||
nano /etc/amnezia/amneziawg/awg0.conf
|
||||
```
|
||||
|
||||
Содержимое файла
|
||||
```ini
|
||||
[Interface]
|
||||
Address = 10.10.10.1/24
|
||||
ListenPort = 8443
|
||||
PrivateKey = <PRIVATE_KEY_SERVER_B>
|
||||
SaveConfig = true
|
||||
Jc = 4
|
||||
Jmin = 8
|
||||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
H4 = 432174303
|
||||
|
||||
[Peer]
|
||||
PublicKey = <PUBLIC_KEY_SERVER_A>
|
||||
AllowedIPs = 10.10.10.2/32
|
||||
```
|
||||
|
||||
`ListenPort` - порт, на котором сервер будет ждать подключения, можете выбрать любой свободный.\
|
||||
`<PRIVATE_KEY_SERVER_B>` - содержимое файла `private.key` с сервера **B**.\
|
||||
`<PUBLIC_KEY_SERVER_A>` - содержимое файла `public.key` с сервера **A**.
|
||||
|
||||
Открываем порт на фаерволе (если включен):
|
||||
```bash
|
||||
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
|
||||
```
|
||||
|
||||
`<PUBLIC_IP_SERVER_A>` - внешний IP адрес Сервера **A**.
|
||||
|
||||
#### Конфигурация Сервера A (_РФ_):
|
||||
|
||||
Создаем файл конфигурации интерфейса (`awg0`)
|
||||
```bash
|
||||
nano /etc/amnezia/amneziawg/awg0.conf
|
||||
```
|
||||
|
||||
Содержимое файла
|
||||
```ini
|
||||
[Interface]
|
||||
Address = 10.10.10.2/24
|
||||
PrivateKey = <PRIVATE_KEY_SERVER_A>
|
||||
Jc = 4
|
||||
Jmin = 8
|
||||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
H4 = 432174303
|
||||
I1 = <b 0xc10000000108981eba846e21f74e00>
|
||||
I2 = <b 0xc20000000108981eba846e21f74e00>
|
||||
I3 = <b 0xc30000000108981eba846e21f74e00>
|
||||
I4 = <b 0x43981eba846e21f74e>
|
||||
I5 = <b 0x43981eba846e21f74e>
|
||||
|
||||
[Peer]
|
||||
PublicKey = <PUBLIC_KEY_SERVER_B>
|
||||
Endpoint = <PUBLIC_IP_SERVER_B>:8443
|
||||
AllowedIPs = 10.10.10.1/32
|
||||
PersistentKeepalive = 25
|
||||
```
|
||||
|
||||
`<PRIVATE_KEY_SERVER_A>` - содержимое файла `private.key` с сервера **A**.\
|
||||
`<PUBLIC_KEY_SERVER_B>` - содержимое файла `public.key` с сервера **B**.\
|
||||
`<PUBLIC_IP_SERVER_B>` - публичный IP адресс сервера **B**.
|
||||
|
||||
#### Включаем туннель на обоих серверах:
|
||||
```bash
|
||||
sudo systemctl enable --now awg-quick@awg0
|
||||
```
|
||||
|
||||
Убедитесь, что с Сервера `A` доступен Сервер `B` через туннель.
|
||||
```bash
|
||||
ping 10.10.10.1
|
||||
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
|
||||
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
|
||||
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
|
||||
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
|
||||
^C
|
||||
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Шаг 2. Установка telemt на Сервере B (_условно Нидерланды_)
|
||||
|
||||
Установка и настройка описаны [здесь](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) или [здесь](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
|
||||
Подразумевается что telemt ожидает подключения на порту `443\tcp`.
|
||||
|
||||
В конфиге telemt необходимо включить протокол `Proxy` и ограничить подключения к нему только через туннель.
|
||||
|
||||
```toml
|
||||
[server]
|
||||
port = 443
|
||||
listen_addr_ipv4 = "10.10.10.1"
|
||||
proxy_protocol = true
|
||||
```
|
||||
|
||||
А также, для правильной генерации ссылок, указать FQDN или IP адрес и порт Сервера `A`
|
||||
|
||||
```toml
|
||||
[general.links]
|
||||
show = "*"
|
||||
public_host = "<FQDN_OR_IP_SERVER_A>"
|
||||
public_port = 443
|
||||
```
|
||||
|
||||
Открываем порт на фаерволе (если включен):
|
||||
```bash
|
||||
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Шаг 3. Настройка HAProxy на Сервере A (_РФ_)
|
||||
|
||||
Т.к. в стандартном репозитории Ubuntu версия относительно старая, имеет смысл воспользоваться официальным образом Docker.\
|
||||
[Инструкция](https://docs.docker.com/engine/install/ubuntu/) по установке Docker на Ubuntu.
|
||||
|
||||
> [!WARNING]
|
||||
> По умолчанию у обычных пользователей нет прав на использование портов < 1024.\
|
||||
> Попытки запустить HAProxy на 443 порту могут приводить к ошибкам:
|
||||
> ```
|
||||
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
|
||||
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
|
||||
> ```
|
||||
> Есть два простых способа обойти это ограничение, выберите что-то одно:
|
||||
> 1. На уровне ОС изменить настройку net.ipv4.ip_unprivileged_port_start, разрешив пользователям использовать все порты:
|
||||
> ```
|
||||
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
|
||||
> ```
|
||||
> или
|
||||
>
|
||||
> 2. Запустить HAProxy под root:\
|
||||
> Раскомментируйте в docker-compose.yaml параметр `user: "root"`.
|
||||
|
||||
#### Создаем папку для HAProxy:
|
||||
```bash
|
||||
mkdir -p /opt/docker-compose/haproxy && cd $_
|
||||
```
|
||||
#### Создаем файл docker-compose.yaml
|
||||
|
||||
`nano docker-compose.yaml`
|
||||
|
||||
Содержимое файла
|
||||
```yaml
|
||||
services:
|
||||
haproxy:
|
||||
image: haproxy:latest
|
||||
container_name: haproxy
|
||||
restart: unless-stopped
|
||||
# user: "root"
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "1m"
|
||||
max-file: "1"
|
||||
```
|
||||
#### Создаем файл конфига haproxy.cfg
|
||||
Принимаем подключения на порту 443\tcp и отправляем их через туннель на Сервер `B` 10.10.10.1:443
|
||||
|
||||
`nano haproxy.cfg`
|
||||
|
||||
Содержимое файла
|
||||
```haproxy
|
||||
global
|
||||
log stdout format raw local0
|
||||
maxconn 10000
|
||||
|
||||
defaults
|
||||
log global
|
||||
mode tcp
|
||||
option tcplog
|
||||
option clitcpka
|
||||
option srvtcpka
|
||||
timeout connect 5s
|
||||
timeout client 2h
|
||||
timeout server 2h
|
||||
timeout check 5s
|
||||
|
||||
frontend tcp_in_443
|
||||
bind *:443
|
||||
maxconn 8000
|
||||
option tcp-smart-accept
|
||||
default_backend telemt_nodes
|
||||
|
||||
backend telemt_nodes
|
||||
option tcp-smart-connect
|
||||
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
|
||||
|
||||
|
||||
```
|
||||
>[!WARNING]
|
||||
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!**
|
||||
|
||||
#### Разрешаем порт 443\tcp в фаерволе (если включен)
|
||||
```bash
|
||||
sudo ufw allow 443/tcp
|
||||
```
|
||||
|
||||
#### Запускаем контейнер HAProxy
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
Если все настроено верно, то теперь можно пробовать подключить клиентов Telegram с использованием ссылок из лога\api telemt.
|
||||
|
|
@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
|
|||
512
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) fn default_mask_relay_max_bytes() -> usize {
|
||||
5 * 1024 * 1024
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn default_mask_relay_max_bytes() -> usize {
|
||||
32 * 1024
|
||||
}
|
||||
|
||||
pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|
|||
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|
||||
|| old.censorship.mask_shape_above_cap_blur_max_bytes
|
||||
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|
||||
|| old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes
|
||||
|| old.censorship.mask_classifier_prefetch_timeout_ms
|
||||
!= new.censorship.mask_classifier_prefetch_timeout_ms
|
||||
|| old.censorship.mask_timing_normalization_enabled
|
||||
!= new.censorship.mask_timing_normalization_enabled
|
||||
|| old.censorship.mask_timing_normalization_floor_ms
|
||||
|
|
|
|||
|
|
@ -430,6 +430,25 @@ impl ProxyConfig {
|
|||
));
|
||||
}
|
||||
|
||||
if config.censorship.mask_relay_max_bytes == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.mask_relay_max_bytes must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if config.censorship.mask_relay_max_bytes > 67_108_864 {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if config.censorship.mask_timing_normalization_ceiling_ms
|
||||
< config.censorship.mask_timing_normalization_floor_ms
|
||||
{
|
||||
|
|
@ -1134,6 +1153,10 @@ mod load_security_tests;
|
|||
#[path = "tests/load_mask_shape_security_tests.rs"]
|
||||
mod load_mask_shape_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
|
||||
mod load_mask_classifier_prefetch_timeout_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
use super::*;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn write_temp_config(contents: &str) -> PathBuf {
|
||||
let nonce = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system time must be after unix epoch")
|
||||
.as_nanos();
|
||||
let path = std::env::temp_dir()
|
||||
.join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml"));
|
||||
fs::write(&path, contents).expect("temp config write must succeed");
|
||||
path
|
||||
}
|
||||
|
||||
fn remove_temp_config(path: &PathBuf) {
|
||||
let _ = fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_classifier_prefetch_timeout_ms = 4
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path)
|
||||
.expect_err("prefetch timeout below minimum security bound must be rejected");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
|
||||
"error must explain timeout bound invariant, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_classifier_prefetch_timeout_ms = 51
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path)
|
||||
.expect_err("prefetch timeout above max security bound must be rejected");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
|
||||
"error must explain timeout bound invariant, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_classifier_prefetch_timeout_ms = 20
|
||||
"#,
|
||||
);
|
||||
|
||||
let cfg = ProxyConfig::load(&path)
|
||||
.expect("prefetch timeout within security bounds must be accepted");
|
||||
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
|
@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8
|
|||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_zero_mask_relay_max_bytes() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_relay_max_bytes = 0
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
|
||||
"error must explain non-zero relay cap invariant, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_relay_max_bytes = 67108865
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path)
|
||||
.expect_err("mask_relay_max_bytes above hard cap must be rejected");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
|
||||
"error must explain relay cap upper bound invariant, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_accepts_valid_mask_relay_max_bytes() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[censorship]
|
||||
mask_relay_max_bytes = 8388608
|
||||
"#,
|
||||
);
|
||||
|
||||
let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
|
||||
assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig {
|
|||
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
|
||||
pub mask_shape_above_cap_blur_max_bytes: usize,
|
||||
|
||||
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
|
||||
#[serde(default = "default_mask_relay_max_bytes")]
|
||||
pub mask_relay_max_bytes: usize,
|
||||
|
||||
/// Prefetch timeout (ms) for extending fragmented masking classifier window.
|
||||
#[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
|
||||
pub mask_classifier_prefetch_timeout_ms: u64,
|
||||
|
||||
/// Enable outcome-time normalization envelope for masking fallback.
|
||||
#[serde(default = "default_mask_timing_normalization_enabled")]
|
||||
pub mask_timing_normalization_enabled: bool,
|
||||
|
|
@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig {
|
|||
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
|
||||
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
|
||||
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
|
||||
mask_relay_max_bytes: default_mask_relay_max_bytes(),
|
||||
mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
|
||||
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
|
||||
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
|
||||
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),
|
||||
|
|
|
|||
|
|
@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches {
|
|||
pub(crate) detected_ip_v6: Option<IpAddr>,
|
||||
}
|
||||
|
||||
const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
|
||||
|
||||
fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
|
||||
crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
|
||||
QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn spawn_runtime_tasks(
|
||||
config: &Arc<ProxyConfig>,
|
||||
|
|
@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||
rc_clone.run_periodic_cleanup().await;
|
||||
});
|
||||
|
||||
spawn_quota_lock_maintenance_task();
|
||||
|
||||
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
|
||||
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
|
||||
debug!(
|
||||
|
|
@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
|
|||
.await;
|
||||
startup_tracker.mark_ready().await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
|
||||
crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
|
||||
|
||||
let handle = spawn_quota_lock_maintenance_task();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
|
||||
1,
|
||||
"runtime maintenance path must spawn exactly one quota lock evictor task per call"
|
||||
);
|
||||
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
|
|||
}
|
||||
}
|
||||
|
||||
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
|
||||
#[cfg(test)]
|
||||
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
|
||||
|
||||
fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
|
||||
Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
|
||||
}
|
||||
|
||||
fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
|
||||
if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
|
||||
return false;
|
||||
}
|
||||
|
||||
if initial_data.is_empty() {
|
||||
// Empty initial_data means there is no client probe prefix to refine.
|
||||
// Prefetching in this case can consume fallback relay payload bytes and
|
||||
// accidentally route them through shaping heuristics.
|
||||
return false;
|
||||
}
|
||||
|
||||
if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
|
||||
return false;
|
||||
}
|
||||
|
||||
initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn extend_masking_initial_window_with_timeout<R>(
|
||||
reader: &mut R,
|
||||
initial_data: &mut Vec<u8>,
|
||||
prefetch_timeout: Duration,
|
||||
)
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
if !should_prefetch_mask_classifier_window(initial_data) {
|
||||
return;
|
||||
}
|
||||
|
||||
let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
|
||||
if need == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
|
||||
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
|
||||
&& n > 0
|
||||
{
|
||||
initial_data.extend_from_slice(&extra[..n]);
|
||||
}
|
||||
}
|
||||
|
||||
fn masking_outcome<R, W>(
|
||||
reader: R,
|
||||
writer: W,
|
||||
|
|
@ -200,6 +261,15 @@ where
|
|||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
HandshakeOutcome::NeedsMasking(Box::pin(async move {
|
||||
let mut reader = reader;
|
||||
let mut initial_data = initial_data;
|
||||
extend_masking_initial_window_with_timeout(
|
||||
&mut reader,
|
||||
&mut initial_data,
|
||||
mask_classifier_prefetch_timeout(&config),
|
||||
)
|
||||
.await;
|
||||
|
||||
handle_bad_client(
|
||||
reader,
|
||||
writer,
|
||||
|
|
@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
|
|||
#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
|
||||
mod masking_probe_evasion_blackhat_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
|
||||
mod masking_fragmented_classifier_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_replay_timing_security_tests.rs"]
|
||||
mod masking_replay_timing_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
|
||||
mod masking_http2_fragmented_preface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
|
||||
mod masking_prefetch_invariant_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"]
|
||||
mod masking_prefetch_timing_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"]
|
||||
mod masking_prefetch_config_runtime_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"]
|
||||
mod masking_prefetch_config_pipeline_integration_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"]
|
||||
mod masking_prefetch_strict_boundary_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
|
||||
mod beobachten_ttl_bounds_security_tests;
|
||||
|
|
|
|||
|
|
@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
|||
hasher.finish() as usize
|
||||
}
|
||||
|
||||
fn auth_probe_scan_start_offset(
|
||||
peer_ip: IpAddr,
|
||||
now: Instant,
|
||||
state_len: usize,
|
||||
scan_limit: usize,
|
||||
) -> usize {
|
||||
if state_len == 0 || scan_limit == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
auth_probe_eviction_offset(peer_ip, now) % state_len
|
||||
}
|
||||
|
||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = auth_probe_state_map();
|
||||
|
|
@ -269,11 +282,7 @@ fn auth_probe_record_failure_with_state(
|
|||
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||
let state_len = state.len();
|
||||
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
|
||||
let start_offset = if state_len == 0 {
|
||||
0
|
||||
} else {
|
||||
auth_probe_eviction_offset(peer_ip, now) % state_len
|
||||
};
|
||||
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
|
||||
|
||||
let mut scanned = 0usize;
|
||||
for entry in state.iter().skip(start_offset) {
|
||||
|
|
@ -769,7 +778,7 @@ where
|
|||
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
|
||||
dec_key_input.extend_from_slice(dec_prekey);
|
||||
dec_key_input.extend_from_slice(&secret);
|
||||
let dec_key = sha256(&dec_key_input);
|
||||
let dec_key = Zeroizing::new(sha256(&dec_key_input));
|
||||
|
||||
let mut dec_iv_arr = [0u8; IV_LEN];
|
||||
dec_iv_arr.copy_from_slice(dec_iv_bytes);
|
||||
|
|
@ -805,7 +814,7 @@ where
|
|||
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
|
||||
enc_key_input.extend_from_slice(enc_prekey);
|
||||
enc_key_input.extend_from_slice(&secret);
|
||||
let enc_key = sha256(&enc_key_input);
|
||||
let enc_key = Zeroizing::new(sha256(&enc_key_input));
|
||||
|
||||
let mut enc_iv_arr = [0u8; IV_LEN];
|
||||
enc_iv_arr.copy_from_slice(enc_iv_bytes);
|
||||
|
|
@ -830,9 +839,9 @@ where
|
|||
user: user.clone(),
|
||||
dc_idx,
|
||||
proto_tag,
|
||||
dec_key,
|
||||
dec_key: *dec_key,
|
||||
dec_iv,
|
||||
enc_key,
|
||||
enc_key: *enc_key,
|
||||
enc_iv,
|
||||
peer,
|
||||
is_tls,
|
||||
|
|
@ -979,6 +988,18 @@ mod saturation_poison_security_tests;
|
|||
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
|
||||
mod auth_probe_hardening_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"]
|
||||
mod auth_probe_scan_budget_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
|
||||
mod auth_probe_scan_offset_stress_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"]
|
||||
mod auth_probe_eviction_bias_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/handshake_advanced_clever_tests.rs"]
|
||||
mod advanced_clever_tests;
|
||||
|
|
@ -995,6 +1016,10 @@ mod real_bug_stress_tests;
|
|||
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
|
||||
mod timing_manual_bench_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
|
||||
mod handshake_key_material_zeroization_security_tests;
|
||||
|
||||
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
|
||||
/// must never be Copy. A Copy impl would allow silent key duplication,
|
||||
/// undermining the zeroize-on-drop guarantee.
|
||||
|
|
|
|||
|
|
@ -4,14 +4,23 @@ use crate::config::ProxyConfig;
|
|||
use crate::network::dns_overrides::resolve_socket_addr;
|
||||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||
use rand::{Rng, RngExt};
|
||||
use std::net::SocketAddr;
|
||||
#[cfg(unix)]
|
||||
use nix::ifaddrs::getifaddrs;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, RngExt, SeedableRng};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str;
|
||||
use std::time::Duration;
|
||||
#[cfg(unix)]
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::{Duration, Instant as StdInstant};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::time::{Instant, timeout};
|
||||
use tracing::debug;
|
||||
|
||||
|
|
@ -30,13 +39,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
|
|||
#[cfg(test)]
|
||||
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
const MASK_BUFFER_SIZE: usize = 8192;
|
||||
#[cfg(unix)]
|
||||
#[cfg(not(test))]
|
||||
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
|
||||
#[cfg(all(unix, test))]
|
||||
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);
|
||||
|
||||
struct CopyOutcome {
|
||||
total: usize,
|
||||
ended_by_eof: bool,
|
||||
}
|
||||
|
||||
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
|
||||
async fn copy_with_idle_timeout<R, W>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
byte_cap: usize,
|
||||
shutdown_on_eof: bool,
|
||||
) -> CopyOutcome
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
|
|
@ -44,14 +63,31 @@ where
|
|||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut total = 0usize;
|
||||
let mut ended_by_eof = false;
|
||||
|
||||
if byte_cap == 0 {
|
||||
return CopyOutcome {
|
||||
total,
|
||||
ended_by_eof,
|
||||
};
|
||||
}
|
||||
|
||||
loop {
|
||||
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
|
||||
let remaining_budget = byte_cap.saturating_sub(total);
|
||||
if remaining_budget == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
|
||||
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
|
||||
let n = match read_res {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
};
|
||||
if n == 0 {
|
||||
ended_by_eof = true;
|
||||
if shutdown_on_eof {
|
||||
let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
total = total.saturating_add(n);
|
||||
|
|
@ -68,6 +104,39 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn is_http_probe(data: &[u8]) -> bool {
|
||||
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
|
||||
const HTTP_METHODS: [&[u8]; 10] = [
|
||||
b"GET ",
|
||||
b"POST",
|
||||
b"HEAD",
|
||||
b"PUT ",
|
||||
b"DELETE",
|
||||
b"OPTIONS",
|
||||
b"CONNECT",
|
||||
b"TRACE",
|
||||
b"PATCH",
|
||||
b"PRI ",
|
||||
];
|
||||
|
||||
if data.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let window = &data[..data.len().min(16)];
|
||||
for method in HTTP_METHODS {
|
||||
if data.len() >= method.len() && window.starts_with(method) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (2..=3).contains(&window.len()) && method.starts_with(window) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
|
||||
if total == 0 || floor == 0 || cap < floor {
|
||||
return total;
|
||||
|
|
@ -125,6 +194,11 @@ async fn maybe_write_shape_padding<W>(
|
|||
let mut remaining = target_total - total_sent;
|
||||
let mut pad_chunk = [0u8; 1024];
|
||||
let deadline = Instant::now() + MASK_TIMEOUT;
|
||||
// Use a Send RNG so relay futures remain spawn-safe under Tokio.
|
||||
let mut rng = {
|
||||
let mut seed_source = rand::rng();
|
||||
StdRng::from_rng(&mut seed_source)
|
||||
};
|
||||
|
||||
while remaining > 0 {
|
||||
let now = Instant::now();
|
||||
|
|
@ -133,10 +207,7 @@ async fn maybe_write_shape_padding<W>(
|
|||
}
|
||||
|
||||
let write_len = remaining.min(pad_chunk.len());
|
||||
{
|
||||
let mut rng = rand::rng();
|
||||
rng.fill_bytes(&mut pad_chunk[..write_len]);
|
||||
}
|
||||
rng.fill_bytes(&mut pad_chunk[..write_len]);
|
||||
let write_budget = deadline.saturating_duration_since(now);
|
||||
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
|
||||
Ok(Ok(())) => {}
|
||||
|
|
@ -167,11 +238,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
async fn consume_client_data_with_timeout<R>(reader: R)
|
||||
async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
|
||||
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
|
|
@ -190,6 +261,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
|
|||
if config.censorship.mask_timing_normalization_enabled {
|
||||
let floor = config.censorship.mask_timing_normalization_floor_ms;
|
||||
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
|
||||
if floor == 0 {
|
||||
return MASK_TIMEOUT;
|
||||
}
|
||||
if ceiling > floor {
|
||||
let mut rng = rand::rng();
|
||||
return Duration::from_millis(rng.random_range(floor..=ceiling));
|
||||
|
|
@ -219,14 +293,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
|
|||
/// Detect client type based on initial data
|
||||
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||
// Check for HTTP request
|
||||
if data.len() > 4
|
||||
&& (data.starts_with(b"GET ")
|
||||
|| data.starts_with(b"POST")
|
||||
|| data.starts_with(b"HEAD")
|
||||
|| data.starts_with(b"PUT ")
|
||||
|| data.starts_with(b"DELETE")
|
||||
|| data.starts_with(b"OPTIONS"))
|
||||
{
|
||||
if is_http_probe(data) {
|
||||
return "HTTP";
|
||||
}
|
||||
|
||||
|
|
@ -248,6 +315,244 @@ fn detect_client_type(data: &[u8]) -> &'static str {
|
|||
"unknown"
|
||||
}
|
||||
|
||||
fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
|
||||
if host.starts_with('[') && host.ends_with(']') {
|
||||
return host[1..host.len() - 1].parse::<IpAddr>().ok();
|
||||
}
|
||||
host.parse::<IpAddr>().ok()
|
||||
}
|
||||
|
||||
fn canonical_ip(ip: IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)),
|
||||
IpAddr::V4(v4) => IpAddr::V4(v4),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn collect_local_interface_ips() -> Vec<IpAddr> {
|
||||
#[cfg(test)]
|
||||
LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let mut out = Vec::new();
|
||||
if let Ok(addrs) = getifaddrs() {
|
||||
for iface in addrs {
|
||||
if let Some(address) = iface.address {
|
||||
if let Some(v4) = address.as_sockaddr_in() {
|
||||
out.push(canonical_ip(IpAddr::V4(v4.ip())));
|
||||
} else if let Some(v6) = address.as_sockaddr_in6() {
|
||||
out.push(canonical_ip(IpAddr::V6(v6.ip())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
|
||||
if refreshed.is_empty() && !previous.is_empty() {
|
||||
return previous.to_vec();
|
||||
}
|
||||
|
||||
refreshed
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Default)]
|
||||
struct LocalInterfaceCache {
|
||||
ips: Vec<IpAddr>,
|
||||
refreshed_at: Option<StdInstant>,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
|
||||
|
||||
#[cfg(unix)]
|
||||
static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
|
||||
|
||||
#[cfg(all(unix, test))]
|
||||
fn local_interface_ips() -> Vec<IpAddr> {
|
||||
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
|
||||
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stale = guard
|
||||
.refreshed_at
|
||||
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
|
||||
if stale {
|
||||
let refreshed = collect_local_interface_ips();
|
||||
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
|
||||
guard.refreshed_at = Some(StdInstant::now());
|
||||
}
|
||||
|
||||
guard.ips.clone()
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn local_interface_ips_async() -> Vec<IpAddr> {
|
||||
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
|
||||
|
||||
{
|
||||
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
let stale = guard
|
||||
.refreshed_at
|
||||
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
|
||||
if !stale {
|
||||
return guard.ips.clone();
|
||||
}
|
||||
}
|
||||
|
||||
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
|
||||
let _refresh_guard = refresh_lock.lock().await;
|
||||
|
||||
{
|
||||
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
let stale = guard
|
||||
.refreshed_at
|
||||
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
|
||||
if !stale {
|
||||
return guard.ips.clone();
|
||||
}
|
||||
}
|
||||
|
||||
let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
let stale = guard
|
||||
.refreshed_at
|
||||
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
|
||||
if stale {
|
||||
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
|
||||
guard.refreshed_at = Some(StdInstant::now());
|
||||
}
|
||||
|
||||
guard.ips.clone()
|
||||
}
|
||||
|
||||
#[cfg(all(not(unix), test))]
|
||||
fn local_interface_ips() -> Vec<IpAddr> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
async fn local_interface_ips_async() -> Vec<IpAddr> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
fn reset_local_interface_enumerations_for_tests() {
|
||||
LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
|
||||
|
||||
#[cfg(unix)]
|
||||
if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
|
||||
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
guard.ips.clear();
|
||||
guard.refreshed_at = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn local_interface_enumerations_for_tests() -> usize {
|
||||
LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn is_mask_target_local_listener_with_interfaces(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
interface_ips: &[IpAddr],
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let local_ip = canonical_ip(local_addr.ip());
|
||||
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
|
||||
|
||||
if let Some(addr) = resolved_override {
|
||||
let resolved_ip = canonical_ip(addr.ip());
|
||||
if resolved_ip == local_ip {
|
||||
return true;
|
||||
}
|
||||
|
||||
if local_ip.is_unspecified()
|
||||
&& (resolved_ip.is_loopback()
|
||||
|| resolved_ip.is_unspecified()
|
||||
|| interface_ips.contains(&resolved_ip))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mask_ip) = literal_mask_ip {
|
||||
if mask_ip == local_ip {
|
||||
return true;
|
||||
}
|
||||
|
||||
if local_ip.is_unspecified()
|
||||
&& (mask_ip.is_loopback()
|
||||
|| mask_ip.is_unspecified()
|
||||
|| interface_ips.contains(&mask_ip))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_mask_target_local_listener(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let interfaces = local_interface_ips();
|
||||
is_mask_target_local_listener_with_interfaces(
|
||||
mask_host,
|
||||
mask_port,
|
||||
local_addr,
|
||||
resolved_override,
|
||||
&interfaces,
|
||||
)
|
||||
}
|
||||
|
||||
async fn is_mask_target_local_listener_async(
|
||||
mask_host: &str,
|
||||
mask_port: u16,
|
||||
local_addr: SocketAddr,
|
||||
resolved_override: Option<SocketAddr>,
|
||||
) -> bool {
|
||||
if mask_port != local_addr.port() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let interfaces = local_interface_ips_async().await;
|
||||
is_mask_target_local_listener_with_interfaces(
|
||||
mask_host,
|
||||
mask_port,
|
||||
local_addr,
|
||||
resolved_override,
|
||||
&interfaces,
|
||||
)
|
||||
}
|
||||
|
||||
fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
|
||||
let minutes = config.general.beobachten_minutes;
|
||||
let clamped = minutes.clamp(1, 24 * 60);
|
||||
Duration::from_secs(clamped.saturating_mul(60))
|
||||
}
|
||||
|
||||
fn build_mask_proxy_header(
|
||||
version: u8,
|
||||
peer: SocketAddr,
|
||||
|
|
@ -290,13 +595,14 @@ pub async fn handle_bad_client<R, W>(
|
|||
{
|
||||
let client_type = detect_client_type(initial_data);
|
||||
if config.general.beobachten {
|
||||
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
|
||||
let ttl = masking_beobachten_ttl(config);
|
||||
beobachten.record(client_type, peer.ip(), ttl);
|
||||
}
|
||||
|
||||
if !config.censorship.mask {
|
||||
// Masking disabled, just consume data
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -341,6 +647,7 @@ pub async fn handle_bad_client<R, W>(
|
|||
config.censorship.mask_shape_above_cap_blur,
|
||||
config.censorship.mask_shape_above_cap_blur_max_bytes,
|
||||
config.censorship.mask_shape_hardening_aggressive_mode,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
),
|
||||
)
|
||||
.await
|
||||
|
|
@ -353,12 +660,12 @@ pub async fn handle_bad_client<R, W>(
|
|||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask unix socket");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -372,6 +679,28 @@ pub async fn handle_bad_client<R, W>(
|
|||
.unwrap_or(&config.censorship.tls_domain);
|
||||
let mask_port = config.censorship.mask_port;
|
||||
|
||||
// Fail closed when fallback points at our own listener endpoint.
|
||||
// Self-referential masking can create recursive proxy loops under
|
||||
// misconfiguration and leak distinguishable load spikes to adversaries.
|
||||
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
|
||||
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
|
||||
.await
|
||||
{
|
||||
let outcome_started = Instant::now();
|
||||
debug!(
|
||||
client_type = client_type,
|
||||
host = %mask_host,
|
||||
port = mask_port,
|
||||
local = %local_addr,
|
||||
"Mask target resolves to local listener; refusing self-referential masking fallback"
|
||||
);
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
return;
|
||||
}
|
||||
|
||||
let outcome_started = Instant::now();
|
||||
|
||||
debug!(
|
||||
client_type = client_type,
|
||||
host = %mask_host,
|
||||
|
|
@ -381,10 +710,9 @@ pub async fn handle_bad_client<R, W>(
|
|||
);
|
||||
|
||||
// Apply runtime DNS override for mask target when configured.
|
||||
let mask_addr = resolve_socket_addr(mask_host, mask_port)
|
||||
let mask_addr = resolved_mask_addr
|
||||
.map(|addr| addr.to_string())
|
||||
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
|
||||
let outcome_started = Instant::now();
|
||||
let connect_started = Instant::now();
|
||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||
match connect_result {
|
||||
|
|
@ -413,6 +741,7 @@ pub async fn handle_bad_client<R, W>(
|
|||
config.censorship.mask_shape_above_cap_blur,
|
||||
config.censorship.mask_shape_above_cap_blur_max_bytes,
|
||||
config.censorship.mask_shape_hardening_aggressive_mode,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
),
|
||||
)
|
||||
.await
|
||||
|
|
@ -425,12 +754,12 @@ pub async fn handle_bad_client<R, W>(
|
|||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -449,6 +778,7 @@ async fn relay_to_mask<R, W, MR, MW>(
|
|||
shape_above_cap_blur: bool,
|
||||
shape_above_cap_blur_max_bytes: usize,
|
||||
shape_hardening_aggressive_mode: bool,
|
||||
mask_relay_max_bytes: usize,
|
||||
) where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
|
|
@ -464,8 +794,18 @@ async fn relay_to_mask<R, W, MR, MW>(
|
|||
}
|
||||
|
||||
let (upstream_copy, downstream_copy) = tokio::join!(
|
||||
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
|
||||
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
|
||||
async {
|
||||
copy_with_idle_timeout(
|
||||
&mut reader,
|
||||
&mut mask_write,
|
||||
mask_relay_max_bytes,
|
||||
!shape_hardening_enabled,
|
||||
)
|
||||
.await
|
||||
},
|
||||
async {
|
||||
copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
|
||||
}
|
||||
);
|
||||
|
||||
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
|
||||
|
|
@ -491,13 +831,36 @@ async fn relay_to_mask<R, W, MR, MW>(
|
|||
let _ = writer.shutdown().await;
|
||||
}
|
||||
|
||||
/// Just consume all data from client without responding
|
||||
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
while let Ok(n) = reader.read(&mut buf).await {
|
||||
/// Just consume all data from client without responding.
|
||||
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize) {
|
||||
if byte_cap == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// Keep drain path fail-closed under slow-loris stalls.
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut total = 0usize;
|
||||
|
||||
loop {
|
||||
let remaining_budget = byte_cap.saturating_sub(total);
|
||||
if remaining_budget == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
|
||||
let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
};
|
||||
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
total = total.saturating_add(n);
|
||||
if total >= byte_cap {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -521,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests;
|
|||
#[path = "tests/masking_timing_normalization_security_tests.rs"]
|
||||
mod masking_timing_normalization_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_timing_budget_coupling_security_tests.rs"]
|
||||
mod masking_timing_budget_coupling_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
|
||||
mod masking_ab_envelope_blur_integration_security_tests;
|
||||
|
|
@ -548,3 +915,75 @@ mod masking_aggressive_mode_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
|
||||
mod masking_timing_sidechannel_redteam_expected_fail_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_self_target_loop_security_tests.rs"]
|
||||
mod masking_self_target_loop_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_classification_completeness_security_tests.rs"]
|
||||
mod masking_classification_completeness_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_relay_guardrails_security_tests.rs"]
|
||||
mod masking_relay_guardrails_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"]
|
||||
mod masking_connect_failure_close_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_additional_hardening_security_tests.rs"]
|
||||
mod masking_additional_hardening_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_consume_idle_timeout_security_tests.rs"]
|
||||
mod masking_consume_idle_timeout_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_http2_probe_classification_security_tests.rs"]
|
||||
mod masking_http2_probe_classification_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_http_probe_boundary_security_tests.rs"]
|
||||
mod masking_http_probe_boundary_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"]
|
||||
mod masking_rng_hoist_perf_regression_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_http2_preface_integration_security_tests.rs"]
|
||||
mod masking_http2_preface_integration_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_consume_stress_adversarial_tests.rs"]
|
||||
mod masking_consume_stress_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_interface_cache_security_tests.rs"]
|
||||
mod masking_interface_cache_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
|
||||
mod masking_interface_cache_defense_in_depth_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"]
|
||||
mod masking_interface_cache_concurrency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
|
||||
mod masking_production_cap_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
|
||||
mod masking_extended_attack_surface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
|
||||
mod masking_padding_timeout_adversarial_tests;
|
||||
|
||||
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
|
||||
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
|
||||
mod masking_offline_target_redteam_expected_fail_tests;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use std::collections::hash_map::RandomState;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
#[cfg(test)]
|
||||
use std::future::Future;
|
||||
use std::hash::{BuildHasher, Hash};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
|
|
@ -39,10 +41,14 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
|||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
|
||||
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
|
||||
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
|
||||
#[cfg(test)]
|
||||
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
|
||||
#[cfg(not(test))]
|
||||
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
#[cfg(test)]
|
||||
const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
|
||||
|
|
@ -94,10 +100,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
|
|||
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
|
||||
}
|
||||
|
||||
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> {
|
||||
let registry = relay_idle_candidate_registry();
|
||||
match registry.lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(poisoned) => {
|
||||
let mut guard = poisoned.into_inner();
|
||||
// Fail closed after panic while holding registry lock: drop all
|
||||
// candidates and pressure cursors to avoid stale cross-session state.
|
||||
*guard = RelayIdleCandidateRegistry::default();
|
||||
registry.clear_poison();
|
||||
guard
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn mark_relay_idle_candidate(conn_id: u64) -> bool {
|
||||
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
|
||||
return false;
|
||||
};
|
||||
let mut guard = relay_idle_candidate_registry_lock();
|
||||
|
||||
if guard.by_conn_id.contains_key(&conn_id) {
|
||||
return false;
|
||||
|
|
@ -116,9 +135,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
|
|||
}
|
||||
|
||||
fn clear_relay_idle_candidate(conn_id: u64) {
|
||||
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
|
||||
return;
|
||||
};
|
||||
let mut guard = relay_idle_candidate_registry_lock();
|
||||
|
||||
if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
|
||||
guard.ordered.remove(&(meta.mark_order_seq, conn_id));
|
||||
|
|
@ -127,23 +144,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {
|
|||
|
||||
#[cfg(test)]
|
||||
fn oldest_relay_idle_candidate() -> Option<u64> {
|
||||
let Ok(guard) = relay_idle_candidate_registry().lock() else {
|
||||
return None;
|
||||
};
|
||||
let guard = relay_idle_candidate_registry_lock();
|
||||
guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
|
||||
}
|
||||
|
||||
fn note_relay_pressure_event() {
|
||||
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
|
||||
return;
|
||||
};
|
||||
let mut guard = relay_idle_candidate_registry_lock();
|
||||
guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
|
||||
}
|
||||
|
||||
fn relay_pressure_event_seq() -> u64 {
|
||||
let Ok(guard) = relay_idle_candidate_registry().lock() else {
|
||||
return 0;
|
||||
};
|
||||
let guard = relay_idle_candidate_registry_lock();
|
||||
guard.pressure_event_seq
|
||||
}
|
||||
|
||||
|
|
@ -152,9 +163,7 @@ fn maybe_evict_idle_candidate_on_pressure(
|
|||
seen_pressure_seq: &mut u64,
|
||||
stats: &Stats,
|
||||
) -> bool {
|
||||
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
|
||||
return false;
|
||||
};
|
||||
let mut guard = relay_idle_candidate_registry_lock();
|
||||
|
||||
let latest_pressure_seq = guard.pressure_event_seq;
|
||||
if latest_pressure_seq == *seen_pressure_seq {
|
||||
|
|
@ -199,13 +208,9 @@ fn maybe_evict_idle_candidate_on_pressure(
|
|||
|
||||
#[cfg(test)]
|
||||
fn clear_relay_idle_pressure_state_for_testing() {
|
||||
if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
|
||||
&& let Ok(mut guard) = registry.lock()
|
||||
{
|
||||
guard.by_conn_id.clear();
|
||||
guard.ordered.clear();
|
||||
guard.pressure_event_seq = 0;
|
||||
guard.pressure_consumed_seq = 0;
|
||||
if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
|
||||
let mut guard = relay_idle_candidate_registry_lock();
|
||||
*guard = RelayIdleCandidateRegistry::default();
|
||||
}
|
||||
RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
|
@ -259,6 +264,7 @@ impl RelayClientIdlePolicy {
|
|||
struct RelayClientIdleState {
|
||||
last_client_frame_at: Instant,
|
||||
soft_idle_marked: bool,
|
||||
tiny_frame_debt: u32,
|
||||
}
|
||||
|
||||
impl RelayClientIdleState {
|
||||
|
|
@ -266,6 +272,7 @@ impl RelayClientIdleState {
|
|||
Self {
|
||||
last_client_frame_at: now,
|
||||
soft_idle_marked: false,
|
||||
tiny_frame_debt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -535,6 +542,7 @@ fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>)
|
|||
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn quota_would_be_exceeded_for_user(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
|
|
@ -551,15 +559,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
|
|||
limit.saturating_add(overshoot)
|
||||
}
|
||||
|
||||
fn quota_exceeded_for_user_soft(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
overshoot: u64,
|
||||
) -> bool {
|
||||
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
|
||||
}
|
||||
|
||||
fn quota_would_be_exceeded_for_user_soft(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
|
|
@ -567,11 +566,8 @@ fn quota_would_be_exceeded_for_user_soft(
|
|||
bytes: u64,
|
||||
overshoot: u64,
|
||||
) -> bool {
|
||||
quota_limit.is_some_and(|quota| {
|
||||
let cap = quota_soft_cap(quota, overshoot);
|
||||
let used = stats.get_user_total_octets(user);
|
||||
used >= cap || bytes > cap.saturating_sub(used)
|
||||
})
|
||||
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
|
||||
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
|
||||
}
|
||||
|
||||
fn classify_me_d2c_flush_reason(
|
||||
|
|
@ -617,6 +613,16 @@ fn observe_me_d2c_flush_event(
|
|||
}
|
||||
}
|
||||
|
||||
fn rollback_me2c_quota_reservation(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
bytes_me2c: &AtomicU64,
|
||||
reserved_bytes: u64,
|
||||
) {
|
||||
stats.sub_user_octets_to(user, reserved_bytes);
|
||||
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
|
@ -630,6 +636,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
|||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
||||
relay_idle_pressure_test_guard()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
|
|
@ -665,6 +684,11 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
}
|
||||
|
||||
async fn enqueue_c2me_command(
|
||||
tx: &mpsc::Sender<C2MeCommand>,
|
||||
cmd: C2MeCommand,
|
||||
|
|
@ -690,6 +714,16 @@ async fn enqueue_c2me_command(
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn run_relay_test_step_timeout<F, T>(context: &'static str, fut: F) -> T
|
||||
where
|
||||
F: Future<Output = T>,
|
||||
{
|
||||
timeout(RELAY_TEST_STEP_TIMEOUT, fut)
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_via_middle_proxy<R, W>(
|
||||
mut crypto_reader: CryptoReader<R>,
|
||||
crypto_writer: CryptoWriter<W>,
|
||||
|
|
@ -710,6 +744,8 @@ where
|
|||
{
|
||||
let user = success.user.clone();
|
||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||
let cross_mode_quota_lock =
|
||||
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let peer = success.peer;
|
||||
let proto_tag = success.proto_tag;
|
||||
let pool_generation = me_pool.current_generation();
|
||||
|
|
@ -836,6 +872,7 @@ where
|
|||
let stats_clone = stats.clone();
|
||||
let rng_clone = rng.clone();
|
||||
let user_clone = user.clone();
|
||||
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
|
||||
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
|
||||
let bytes_me2c_clone = bytes_me2c.clone();
|
||||
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
|
||||
|
|
@ -857,7 +894,7 @@ where
|
|||
|
||||
let first_is_downstream_activity =
|
||||
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response(
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
first,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -867,6 +904,7 @@ where
|
|||
&user_clone,
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -915,7 +953,7 @@ where
|
|||
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response(
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -925,6 +963,7 @@ where
|
|||
&user_clone,
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -976,7 +1015,7 @@ where
|
|||
Ok(Some(next)) => {
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response(
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -986,6 +1025,7 @@ where
|
|||
&user_clone,
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1039,7 +1079,7 @@ where
|
|||
|
||||
let extra_is_downstream_activity =
|
||||
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response(
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
extra,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -1049,6 +1089,7 @@ where
|
|||
&user_clone,
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1221,6 +1262,14 @@ where
|
|||
if let Some(limit) = quota_limit {
|
||||
let quota_lock = quota_user_lock(&user);
|
||||
let _quota_guard = quota_lock.lock().await;
|
||||
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
|
||||
main_result = Err(ProxyError::Proxy(
|
||||
"cross-mode quota lock missing for quota-limited session"
|
||||
.to_string(),
|
||||
));
|
||||
break;
|
||||
};
|
||||
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
|
||||
main_result = Err(ProxyError::DataQuotaExceeded {
|
||||
|
|
@ -1320,6 +1369,8 @@ async fn read_client_payload_with_idle_policy<R>(
|
|||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
|
||||
|
||||
async fn read_exact_with_policy<R>(
|
||||
client_reader: &mut CryptoReader<R>,
|
||||
buf: &mut [u8],
|
||||
|
|
@ -1458,6 +1509,7 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
let mut consecutive_zero_len_frames = 0u32;
|
||||
loop {
|
||||
let (len, quickack, raw_len_bytes) = match proto_tag {
|
||||
ProtoTag::Abridged => {
|
||||
|
|
@ -1538,6 +1590,27 @@ where
|
|||
};
|
||||
|
||||
if len == 0 {
|
||||
idle_state.tiny_frame_debt = idle_state
|
||||
.tiny_frame_debt
|
||||
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
|
||||
if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
|
||||
stats.increment_relay_protocol_desync_close_total();
|
||||
return Err(ProxyError::Proxy(format!(
|
||||
"Tiny frame overhead limit exceeded: debt={}, conn_id={}",
|
||||
idle_state.tiny_frame_debt, forensics.conn_id
|
||||
)));
|
||||
}
|
||||
|
||||
if !idle_policy.enabled {
|
||||
consecutive_zero_len_frames =
|
||||
consecutive_zero_len_frames.saturating_add(1);
|
||||
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
|
||||
stats.increment_relay_protocol_desync_close_total();
|
||||
return Err(ProxyError::Proxy(
|
||||
"Excessive zero-length abridged frames".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if len < 4 && proto_tag != ProtoTag::Abridged {
|
||||
|
|
@ -1606,6 +1679,7 @@ where
|
|||
}
|
||||
*frame_counter += 1;
|
||||
idle_state.on_client_frame(Instant::now());
|
||||
idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
|
||||
clear_relay_idle_candidate(forensics.conn_id);
|
||||
return Ok(Some((payload, quickack)));
|
||||
}
|
||||
|
|
@ -1681,6 +1755,7 @@ enum MeWriterResponseOutcome {
|
|||
Close,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn process_me_writer_response<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
|
|
@ -1696,6 +1771,44 @@ async fn process_me_writer_response<W>(
|
|||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
response,
|
||||
client_writer,
|
||||
proto_tag,
|
||||
rng,
|
||||
frame_buf,
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
quota_soft_overshoot_bytes,
|
||||
None,
|
||||
bytes_me2c,
|
||||
conn_id,
|
||||
ack_flush_immediate,
|
||||
batched,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn process_me_writer_response_with_cross_mode_lock<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
proto_tag: ProtoTag,
|
||||
rng: &SecureRandom,
|
||||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
quota_soft_overshoot_bytes: u64,
|
||||
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
|
||||
bytes_me2c: &AtomicU64,
|
||||
conn_id: u64,
|
||||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
|
|
@ -1707,39 +1820,76 @@ where
|
|||
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||
}
|
||||
let data_len = data.len() as u64;
|
||||
if quota_would_be_exceeded_for_user_soft(
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
data_len,
|
||||
quota_soft_overshoot_bytes,
|
||||
) {
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
if let Some(limit) = quota_limit {
|
||||
let owned_cross_mode_lock;
|
||||
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
|
||||
lock
|
||||
} else {
|
||||
owned_cross_mode_lock =
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
|
||||
&owned_cross_mode_lock
|
||||
};
|
||||
let cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
|
||||
if quota_would_be_exceeded_for_user_soft(
|
||||
stats,
|
||||
user,
|
||||
Some(limit),
|
||||
data_len,
|
||||
quota_soft_overshoot_bytes,
|
||||
) {
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let write_mode =
|
||||
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await?;
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
|
||||
// If reservation loses a race or write fails, we roll back immediately.
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
|
||||
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data.len() as u64);
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
|
||||
if stats.get_user_total_octets(user) > soft_limit {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if quota_exceeded_for_user_soft(
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
quota_soft_overshoot_bytes,
|
||||
) {
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
|
||||
// but socket I/O proceeds without holding same-user cross-mode admission lock.
|
||||
drop(cross_mode_quota_guard);
|
||||
|
||||
let write_mode =
|
||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await
|
||||
{
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
|
||||
// Do not fail immediately on exact boundary after a successful write.
|
||||
// Returning an error here can bypass batch flush in the caller and risk
|
||||
// dropping buffered ciphertext from CryptoWriter. The next frame is
|
||||
// rejected by the pre-check at function entry.
|
||||
} else {
|
||||
let write_mode =
|
||||
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await?;
|
||||
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
}
|
||||
|
||||
Ok(MeWriterResponseOutcome::Continue {
|
||||
|
|
@ -1978,3 +2128,55 @@ mod length_cast_hardening_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
|
||||
mod blackhat_campaign_integration_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
|
||||
mod hol_quota_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
|
||||
mod quota_reservation_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
|
||||
mod middle_relay_idle_registry_poison_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"]
|
||||
mod middle_relay_zero_length_frame_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"]
|
||||
mod middle_relay_tiny_frame_debt_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"]
|
||||
mod middle_relay_tiny_frame_debt_concurrency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
|
||||
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_reservation_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lock_release_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
|
||||
mod middle_relay_quota_extended_attack_surface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
|
||||
mod middle_relay_quota_reservation_extreme_security_tests;
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ pub mod direct_relay;
|
|||
pub mod handshake;
|
||||
pub mod masking;
|
||||
pub mod middle_relay;
|
||||
pub mod quota_lock_registry;
|
||||
pub mod relay;
|
||||
pub mod route_mode;
|
||||
pub mod session_eviction;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
|
||||
|
||||
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
#[cfg(test)]
|
||||
{
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
let mut entry = lookups.entry(user.to_string()).or_insert(0);
|
||||
*entry += 1;
|
||||
}
|
||||
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
return cross_mode_quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.clear();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.get(user).map(|entry| *entry).unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
|
||||
mod quota_lock_registry_cross_mode_adversarial_tests;
|
||||
|
|
@ -62,7 +62,8 @@ use std::sync::{Arc, Mutex, OnceLock};
|
|||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
|
||||
use tokio::time::Instant;
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::time::{Instant, Sleep};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
// ============= Constants =============
|
||||
|
|
@ -209,12 +210,16 @@ struct StatsIo<S> {
|
|||
counters: Arc<SharedCounters>,
|
||||
stats: Arc<Stats>,
|
||||
user: String,
|
||||
quota_lock: Option<Arc<Mutex<()>>>,
|
||||
cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
|
||||
quota_limit: Option<u64>,
|
||||
quota_exceeded: Arc<AtomicBool>,
|
||||
quota_read_wake_scheduled: bool,
|
||||
quota_write_wake_scheduled: bool,
|
||||
quota_read_retry_active: Arc<AtomicBool>,
|
||||
quota_write_retry_active: Arc<AtomicBool>,
|
||||
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_read_retry_attempt: u8,
|
||||
quota_write_retry_attempt: u8,
|
||||
epoch: Instant,
|
||||
}
|
||||
|
||||
|
|
@ -230,30 +235,29 @@ impl<S> StatsIo<S> {
|
|||
) -> Self {
|
||||
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||
counters.touch(Instant::now(), epoch);
|
||||
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
|
||||
let cross_mode_quota_lock = quota_limit
|
||||
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
Self {
|
||||
inner,
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
quota_lock,
|
||||
cross_mode_quota_lock,
|
||||
quota_limit,
|
||||
quota_exceeded,
|
||||
quota_read_wake_scheduled: false,
|
||||
quota_write_wake_scheduled: false,
|
||||
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
|
||||
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
|
||||
quota_read_retry_sleep: None,
|
||||
quota_write_retry_sleep: None,
|
||||
quota_read_retry_attempt: 0,
|
||||
quota_write_retry_attempt: 0,
|
||||
epoch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Drop for StatsIo<S> {
|
||||
fn drop(&mut self) {
|
||||
self.quota_read_retry_active.store(false, Ordering::Relaxed);
|
||||
self.quota_write_retry_active
|
||||
.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct QuotaIoSentinel;
|
||||
|
||||
|
|
@ -281,20 +285,69 @@ fn is_quota_io_error(err: &io::Error) -> bool {
|
|||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
|
||||
#[cfg(test)]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
|
||||
|
||||
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
if !retry_active.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
|
||||
if !retry_active.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
waker.wake_by_ref();
|
||||
}
|
||||
});
|
||||
#[cfg(test)]
|
||||
static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
|
||||
#[cfg(test)]
|
||||
static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
|
||||
let shift = u32::from(retry_attempt.min(5));
|
||||
let multiplier = 1_u32 << shift;
|
||||
QUOTA_CONTENTION_RETRY_INTERVAL
|
||||
.saturating_mul(multiplier)
|
||||
.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn reset_quota_retry_scheduler(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
) {
|
||||
*wake_scheduled = false;
|
||||
*sleep_slot = None;
|
||||
*retry_attempt = 0;
|
||||
}
|
||||
|
||||
fn poll_quota_retry_sleep(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
cx: &mut Context<'_>,
|
||||
) {
|
||||
if !*wake_scheduled {
|
||||
*wake_scheduled = true;
|
||||
#[cfg(test)]
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
|
||||
*sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
|
||||
*retry_attempt,
|
||||
))));
|
||||
}
|
||||
|
||||
if let Some(sleep) = sleep_slot.as_mut()
|
||||
&& sleep.as_mut().poll(cx).is_ready()
|
||||
{
|
||||
*sleep_slot = None;
|
||||
*wake_scheduled = false;
|
||||
*retry_attempt = retry_attempt.saturating_add(1);
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
}
|
||||
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
|
@ -333,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
|||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn quota_user_lock_evict() {
|
||||
if let Some(locks) = QUOTA_USER_LOCKS.get() {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
|
||||
let interval = interval.max(Duration::from_millis(1));
|
||||
#[cfg(test)]
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
quota_user_lock_evict();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
|
||||
interval: Duration,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
spawn_quota_user_lock_evictor(interval)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
|
@ -357,6 +441,11 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
|
|
@ -368,26 +457,16 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let quota_lock = this
|
||||
.quota_limit
|
||||
.is_some()
|
||||
.then(|| quota_user_lock(&this.user));
|
||||
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => {
|
||||
this.quota_read_wake_scheduled = false;
|
||||
this.quota_read_retry_active.store(false, Ordering::Relaxed);
|
||||
Some(guard)
|
||||
}
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
if !this.quota_read_wake_scheduled {
|
||||
this.quota_read_wake_scheduled = true;
|
||||
this.quota_read_retry_active.store(true, Ordering::Relaxed);
|
||||
spawn_quota_retry_waker(
|
||||
Arc::clone(&this.quota_read_retry_active),
|
||||
cx.waker().clone(),
|
||||
);
|
||||
}
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
|
@ -395,6 +474,29 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
);
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
|
|
@ -460,27 +562,16 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let quota_lock = this
|
||||
.quota_limit
|
||||
.is_some()
|
||||
.then(|| quota_user_lock(&this.user));
|
||||
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => {
|
||||
this.quota_write_wake_scheduled = false;
|
||||
this.quota_write_retry_active
|
||||
.store(false, Ordering::Relaxed);
|
||||
Some(guard)
|
||||
}
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
if !this.quota_write_wake_scheduled {
|
||||
this.quota_write_wake_scheduled = true;
|
||||
this.quota_write_retry_active.store(true, Ordering::Relaxed);
|
||||
spawn_quota_retry_waker(
|
||||
Arc::clone(&this.quota_write_retry_active),
|
||||
cx.waker().clone(),
|
||||
);
|
||||
}
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
|
@ -488,6 +579,29 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
);
|
||||
|
||||
let write_buf = if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
|
|
@ -780,6 +894,10 @@ mod relay_quota_model_adversarial_tests;
|
|||
#[path = "tests/relay_quota_overflow_regression_tests.rs"]
|
||||
mod relay_quota_overflow_regression_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
|
||||
mod relay_quota_extended_attack_surface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
|
||||
mod relay_watchdog_delta_security_tests;
|
||||
|
|
@ -791,3 +909,63 @@ mod relay_quota_waker_storm_adversarial_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
|
||||
mod relay_quota_wake_liveness_regression_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
|
||||
mod relay_quota_lock_identity_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
|
||||
mod relay_cross_mode_quota_lock_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
|
||||
mod relay_quota_retry_scheduler_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
|
||||
mod relay_cross_mode_quota_fairness_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_hol_integration_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
|
||||
mod relay_dual_lock_backoff_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
|
||||
mod relay_dual_lock_contention_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
|
||||
mod relay_dual_lock_race_harness_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
|
||||
mod relay_dual_lock_alternating_contention_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
|
||||
mod relay_quota_retry_allocation_latency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
|
||||
mod relay_quota_lock_eviction_lifecycle_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
|
||||
mod relay_quota_lock_eviction_stress_security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::Duration;
|
||||
|
||||
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats,
|
||||
))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = Vec::new();
|
||||
stream.read_to_end(&mut got).await.unwrap();
|
||||
got
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = true;
|
||||
cfg.general.beobachten_minutes = 1;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.general.modes.classic = false;
|
||||
cfg.general.modes.secure = false;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats.clone(),
|
||||
new_upstream_manager(stats),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
None,
|
||||
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
None,
|
||||
Arc::new(UserIpTracker::new()),
|
||||
beobachten.clone(),
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(b"CONNE").await.unwrap();
|
||||
client_side
|
||||
.write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(
|
||||
forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"),
|
||||
"mask backend must receive the full fragmented CONNECT probe"
|
||||
);
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||
assert!(snapshot.contains("[HTTP]"));
|
||||
assert!(snapshot.contains("198.51.100.251-1"));
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{Duration, sleep};
|
||||
|
||||
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats,
|
||||
))
|
||||
}
|
||||
|
||||
async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
|
||||
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = Vec::new();
|
||||
stream.read_to_end(&mut got).await.unwrap();
|
||||
got
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = true;
|
||||
cfg.general.beobachten_minutes = 1;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.general.modes.classic = false;
|
||||
cfg.general.modes.secure = false;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats.clone(),
|
||||
new_upstream_manager(stats),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
None,
|
||||
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
None,
|
||||
Arc::new(UserIpTracker::new()),
|
||||
beobachten.clone(),
|
||||
false,
|
||||
));
|
||||
|
||||
let first = split_at.min(preface.len());
|
||||
client_side.write_all(&preface[..first]).await.unwrap();
|
||||
if first < preface.len() {
|
||||
sleep(Duration::from_millis(delay_ms)).await;
|
||||
client_side.write_all(&preface[first..]).await.unwrap();
|
||||
}
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(
|
||||
forwarded.starts_with(&preface),
|
||||
"mask backend must receive an intact HTTP/2 preface prefix"
|
||||
);
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||
assert!(snapshot.contains("[HTTP]"));
|
||||
assert!(snapshot.contains(&format!("{}-1", peer.ip())));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
|
||||
let cases = [
|
||||
(2usize, 0u64),
|
||||
(3, 0),
|
||||
(4, 0),
|
||||
(2, 7),
|
||||
(3, 7),
|
||||
(8, 1),
|
||||
];
|
||||
|
||||
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
|
||||
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
|
||||
.parse()
|
||||
.unwrap();
|
||||
run_http2_fragment_case(split_at, delay_ms, peer).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
|
||||
for split_at in 2usize..=12 {
|
||||
let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
|
||||
let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
|
||||
.parse()
|
||||
.unwrap();
|
||||
run_http2_fragment_case(split_at, delay_ms, peer).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{Duration, sleep};
|
||||
|
||||
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats,
|
||||
))
|
||||
}
|
||||
|
||||
async fn run_pipeline_prefetch_case(
|
||||
prefetch_timeout_ms: u64,
|
||||
delayed_tail_ms: u64,
|
||||
peer: SocketAddr,
|
||||
) -> (Vec<u8>, String) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = Vec::new();
|
||||
stream.read_to_end(&mut got).await.unwrap();
|
||||
got
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = true;
|
||||
cfg.general.beobachten_minutes = 1;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
|
||||
cfg.general.modes.classic = false;
|
||||
cfg.general.modes.secure = false;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats.clone(),
|
||||
new_upstream_manager(stats),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
None,
|
||||
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
None,
|
||||
Arc::new(UserIpTracker::new()),
|
||||
beobachten.clone(),
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(b"C").await.unwrap();
|
||||
sleep(Duration::from_millis(delayed_tail_ms)).await;
|
||||
|
||||
client_side
|
||||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||
(forwarded, snapshot)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() {
|
||||
let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap();
|
||||
let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await;
|
||||
|
||||
assert!(
|
||||
forwarded.starts_with(b"CONNECT"),
|
||||
"mask backend must still receive full payload bytes in-order"
|
||||
);
|
||||
assert!(
|
||||
snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"),
|
||||
"unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() {
|
||||
let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap();
|
||||
let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await;
|
||||
|
||||
assert!(
|
||||
forwarded.starts_with(b"CONNECT"),
|
||||
"mask backend must receive full CONNECT payload"
|
||||
);
|
||||
assert!(
|
||||
snapshot.contains("[HTTP]"),
|
||||
"20ms budget should recover delayed fragmented prefix and classify as HTTP"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() {
|
||||
let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap();
|
||||
let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap();
|
||||
let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap();
|
||||
|
||||
let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await;
|
||||
let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await;
|
||||
let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await;
|
||||
|
||||
assert!(
|
||||
snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"),
|
||||
"unexpected 5ms snapshot: {snap5}"
|
||||
);
|
||||
assert!(
|
||||
snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"),
|
||||
"unexpected 20ms snapshot: {snap20}"
|
||||
);
|
||||
assert!(snap50.contains("[HTTP]"));
|
||||
}
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
use super::*;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, sleep};
|
||||
|
||||
#[test]
|
||||
fn prefetch_timeout_budget_reads_from_config() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
assert_eq!(
|
||||
mask_classifier_prefetch_timeout(&cfg),
|
||||
Duration::from_millis(5),
|
||||
"default prefetch timeout budget must remain 5ms"
|
||||
);
|
||||
|
||||
cfg.censorship.mask_classifier_prefetch_timeout_ms = 20;
|
||||
assert_eq!(
|
||||
mask_classifier_prefetch_timeout(&cfg),
|
||||
Duration::from_millis(20),
|
||||
"runtime prefetch timeout budget must follow configured value"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(15)).await;
|
||||
writer
|
||||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
extend_masking_initial_window_with_timeout(
|
||||
&mut reader,
|
||||
&mut initial_data,
|
||||
Duration::from_millis(20),
|
||||
)
|
||||
.await;
|
||||
|
||||
writer_task
|
||||
.await
|
||||
.expect("writer task must not panic in runtime timeout test");
|
||||
|
||||
assert!(
|
||||
initial_data.starts_with(b"CONNECT"),
|
||||
"20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(15)).await;
|
||||
writer
|
||||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
extend_masking_initial_window_with_timeout(
|
||||
&mut reader,
|
||||
&mut initial_data,
|
||||
Duration::from_millis(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
writer_task
|
||||
.await
|
||||
.expect("writer task must not panic in runtime timeout test");
|
||||
|
||||
assert!(
|
||||
!initial_data.starts_with(b"CONNECT"),
|
||||
"5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::crypto::sha256_hmac;
|
||||
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
|
||||
use crate::protocol::tls;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
struct PipelineHarness {
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
}
|
||||
|
||||
fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = mask_port;
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), secret_hex.to_string());
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
PipelineHarness {
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
|
||||
buffer_pool: Arc::new(BufferPool::new()),
|
||||
rng: Arc::new(SecureRandom::new()),
|
||||
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
ip_tracker: Arc::new(UserIpTracker::new()),
|
||||
beobachten: Arc::new(BeobachtenStore::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
|
||||
let total_len = 5 + tls_len;
|
||||
let mut handshake = vec![fill; total_len];
|
||||
|
||||
handshake[0] = 0x16;
|
||||
handshake[1] = 0x03;
|
||||
handshake[2] = 0x01;
|
||||
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
|
||||
|
||||
let session_id_len: usize = 32;
|
||||
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
|
||||
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
|
||||
let computed = sha256_hmac(secret, &handshake);
|
||||
let mut digest = computed;
|
||||
let ts = timestamp.to_le_bytes();
|
||||
for i in 0..4 {
|
||||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
|
||||
handshake
|
||||
}
|
||||
|
||||
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
|
||||
let mut record = Vec::with_capacity(5 + payload.len());
|
||||
record.push(0x17);
|
||||
record.extend_from_slice(&TLS_VERSION);
|
||||
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
|
||||
record.extend_from_slice(payload);
|
||||
record
|
||||
}
|
||||
|
||||
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
|
||||
where
|
||||
T: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
let mut body = vec![0u8; len];
|
||||
stream.read_exact(&mut body).await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_initial_data_prefetch_gate_is_fail_closed() {
|
||||
assert!(
|
||||
!should_prefetch_mask_classifier_window(&[]),
|
||||
"empty initial_data must not trigger classifier prefetch"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
|
||||
let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
|
||||
writer.write_all(&payload).await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
|
||||
let mut initial_data = Vec::new();
|
||||
extend_masking_initial_window(&mut reader, &mut initial_data).await;
|
||||
|
||||
assert!(
|
||||
initial_data.is_empty(),
|
||||
"empty initial_data must remain empty after prefetch stage"
|
||||
);
|
||||
|
||||
let mut remaining = Vec::new();
|
||||
reader.read_to_end(&mut remaining).await.unwrap();
|
||||
assert_eq!(
|
||||
remaining, payload,
|
||||
"prefetch stage must not consume fallback payload when initial_data is empty"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
writer
|
||||
.write_all(b"NECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
|
||||
let mut initial_data = b"CON".to_vec();
|
||||
extend_masking_initial_window(&mut reader, &mut initial_data).await;
|
||||
|
||||
assert!(
|
||||
initial_data.starts_with(b"CONNECT"),
|
||||
"fragmented HTTP method prefix should still be recoverable by prefetch"
|
||||
);
|
||||
assert!(
|
||||
initial_data.len() <= 16,
|
||||
"prefetch window must remain bounded"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
|
||||
let mut seed = 0xD15C_A11E_2026_0322u64;
|
||||
|
||||
for _ in 0..128 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let len = ((seed & 0x3f) as usize).saturating_add(1);
|
||||
let mut payload = vec![0u8; len];
|
||||
for (idx, byte) in payload.iter_mut().enumerate() {
|
||||
*byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
|
||||
}
|
||||
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
writer.write_all(&payload).await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
|
||||
let mut initial_data = Vec::new();
|
||||
extend_masking_initial_window(&mut reader, &mut initial_data).await;
|
||||
assert!(initial_data.is_empty());
|
||||
|
||||
let mut remaining = Vec::new();
|
||||
reader.read_to_end(&mut remaining).await.unwrap();
|
||||
assert_eq!(remaining, payload);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let secret = [0xD3u8; 16];
|
||||
let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B);
|
||||
let mut invalid_payload = vec![0u8; HANDSHAKE_LEN];
|
||||
invalid_payload[0] = 0xFF;
|
||||
let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload);
|
||||
let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant");
|
||||
let expected = trailing_record.clone();
|
||||
|
||||
let accept_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
|
||||
let mut got = vec![0u8; expected.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, expected);
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
let n = stream.read(&mut one).await.unwrap();
|
||||
assert_eq!(
|
||||
n, 0,
|
||||
"fallback stream must not append synthetic bytes on empty initial_data path"
|
||||
);
|
||||
});
|
||||
|
||||
let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port());
|
||||
let (server_side, mut client_side) = duplex(131072);
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
"198.51.100.245:56145".parse().unwrap(),
|
||||
harness.config,
|
||||
harness.stats,
|
||||
harness.upstream_manager,
|
||||
harness.replay_checker,
|
||||
harness.buffer_pool,
|
||||
harness.rng,
|
||||
None,
|
||||
harness.route_runtime,
|
||||
None,
|
||||
harness.ip_tracker,
|
||||
harness.beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&client_hello).await.unwrap();
|
||||
let mut head = [0u8; 5];
|
||||
client_side.read_exact(&mut head).await.unwrap();
|
||||
assert_eq!(head[0], 0x16);
|
||||
read_and_discard_tls_record_body(&mut client_side, head).await;
|
||||
|
||||
client_side.write_all(&invalid_mtproto_record).await.unwrap();
|
||||
client_side.write_all(&trailing_record).await.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
use super::*;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, advance, sleep};
|
||||
|
||||
async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(tail_delay_ms)).await;
|
||||
let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await;
|
||||
let _ = writer.shutdown().await;
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
let mut prefetch_task = tokio::spawn(async move {
|
||||
extend_masking_initial_window_with_timeout(
|
||||
&mut reader,
|
||||
&mut initial_data,
|
||||
Duration::from_millis(prefetch_ms),
|
||||
)
|
||||
.await;
|
||||
initial_data
|
||||
});
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
if tail_delay_ms > 0 {
|
||||
advance(Duration::from_millis(tail_delay_ms)).await;
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
if prefetch_ms > tail_delay_ms {
|
||||
advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
let result = prefetch_task.await.expect("prefetch task must not panic");
|
||||
writer_task.await.expect("writer task must not panic");
|
||||
result
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn strict_prefetch_5ms_misses_15ms_tail() {
|
||||
let got = run_strict_prefetch_case(5, 15).await;
|
||||
assert_eq!(got, b"C".to_vec());
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn strict_prefetch_20ms_recovers_15ms_tail() {
|
||||
let got = run_strict_prefetch_case(20, 15).await;
|
||||
assert!(got.starts_with(b"CONNECT"));
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn strict_prefetch_50ms_recovers_35ms_tail() {
|
||||
let got = run_strict_prefetch_case(50, 35).await;
|
||||
assert!(got.starts_with(b"CONNECT"));
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
|
||||
let got = run_strict_prefetch_case(20, 20).await;
|
||||
assert!(got.starts_with(b"CONNECT"));
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn strict_prefetch_one_ms_after_budget_misses_tail() {
|
||||
let got = run_strict_prefetch_case(20, 21).await;
|
||||
assert_eq!(got, b"C".to_vec());
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
use super::*;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, sleep, timeout};
|
||||
|
||||
async fn extend_masking_initial_window_with_budget<R>(
|
||||
reader: &mut R,
|
||||
initial_data: &mut Vec<u8>,
|
||||
prefetch_timeout: Duration,
|
||||
) where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
if !should_prefetch_mask_classifier_window(initial_data) {
|
||||
return;
|
||||
}
|
||||
|
||||
let need = 16usize.saturating_sub(initial_data.len());
|
||||
if need == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut extra = [0u8; 16];
|
||||
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
|
||||
&& n > 0
|
||||
{
|
||||
initial_data.extend_from_slice(&extra[..n]);
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
|
||||
let (mut reader, mut writer) = duplex(1024);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(delayed_tail_ms)).await;
|
||||
writer
|
||||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
extend_masking_initial_window_with_budget(
|
||||
&mut reader,
|
||||
&mut initial_data,
|
||||
Duration::from_millis(prefetch_budget_ms),
|
||||
)
|
||||
.await;
|
||||
|
||||
writer_task
|
||||
.await
|
||||
.expect("writer task must not panic during matrix case");
|
||||
|
||||
initial_data.starts_with(b"CONNECT")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
|
||||
let cases = [
|
||||
// (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
|
||||
(2u64, [true, true, true]),
|
||||
(15u64, [false, true, true]),
|
||||
(35u64, [false, false, true]),
|
||||
];
|
||||
|
||||
for (tail_delay_ms, expected) in cases {
|
||||
let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
|
||||
let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
|
||||
let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
|
||||
|
||||
assert_eq!(
|
||||
got_5, expected[0],
|
||||
"5ms prefetch budget mismatch for tail delay {}ms",
|
||||
tail_delay_ms
|
||||
);
|
||||
assert_eq!(
|
||||
got_20, expected[1],
|
||||
"20ms prefetch budget mismatch for tail delay {}ms",
|
||||
tail_delay_ms
|
||||
);
|
||||
assert_eq!(
|
||||
got_50, expected[2],
|
||||
"50ms prefetch budget mismatch for tail delay {}ms",
|
||||
tail_delay_ms
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn control_current_runtime_prefetch_budget_is_5ms() {
|
||||
assert_eq!(
|
||||
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
|
||||
Duration::from_millis(5),
|
||||
"matrix assumptions require current runtime prefetch budget to stay at 5ms"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::crypto::sha256_hmac;
|
||||
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
|
||||
use crate::protocol::tls;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats,
|
||||
))
|
||||
}
|
||||
|
||||
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
|
||||
let total_len = 5 + tls_len;
|
||||
let mut handshake = vec![fill; total_len];
|
||||
|
||||
handshake[0] = 0x16;
|
||||
handshake[1] = 0x03;
|
||||
handshake[2] = 0x01;
|
||||
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
|
||||
|
||||
let session_id_len: usize = 32;
|
||||
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
|
||||
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
|
||||
let computed = sha256_hmac(secret, &handshake);
|
||||
let mut digest = computed;
|
||||
let ts = timestamp.to_le_bytes();
|
||||
for i in 0..4 {
|
||||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
handshake
|
||||
}
|
||||
|
||||
async fn run_replay_candidate_session(
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
hello: &[u8],
|
||||
peer: SocketAddr,
|
||||
drive_mtproto_fail: bool,
|
||||
) -> Duration {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = 1;
|
||||
cfg.censorship.mask_timing_normalization_enabled = false;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), "abababababababababababababababab".to_string());
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(65536);
|
||||
let started = Instant::now();
|
||||
|
||||
let task = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats.clone(),
|
||||
new_upstream_manager(stats),
|
||||
replay_checker,
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
None,
|
||||
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
None,
|
||||
Arc::new(UserIpTracker::new()),
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(hello).await.unwrap();
|
||||
|
||||
if drive_mtproto_fail {
|
||||
let mut server_hello_head = [0u8; 5];
|
||||
client_side.read_exact(&mut server_hello_head).await.unwrap();
|
||||
assert_eq!(server_hello_head[0], 0x16);
|
||||
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
|
||||
let mut body = vec![0u8; body_len];
|
||||
client_side.read_exact(&mut body).await.unwrap();
|
||||
|
||||
let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN);
|
||||
invalid_mtproto_record.push(0x17);
|
||||
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
|
||||
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
|
||||
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
|
||||
client_side.write_all(&invalid_mtproto_record).await.unwrap();
|
||||
client_side
|
||||
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(4), task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
started.elapsed()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn replay_reject_still_honors_masking_timing_budget() {
|
||||
let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60)));
|
||||
let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51);
|
||||
|
||||
let seed_elapsed = run_replay_candidate_session(
|
||||
Arc::clone(&replay_checker),
|
||||
&hello,
|
||||
"198.51.100.201:58001".parse().unwrap(),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250),
|
||||
"seed replay-candidate run must honor masking timing budget without unbounded delay"
|
||||
);
|
||||
|
||||
let replay_elapsed = run_replay_candidate_session(
|
||||
Arc::clone(&replay_checker),
|
||||
&hello,
|
||||
"198.51.100.202:58002".parse().unwrap(),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
replay_elapsed >= Duration::from_millis(40)
|
||||
&& replay_elapsed < Duration::from_millis(250),
|
||||
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
|
||||
);
|
||||
}
|
||||
|
|
@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess;
|
|||
use crate::stream::{CryptoReader, CryptoWriter};
|
||||
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::RngCore;
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
use std::net::Ipv4Addr;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_large_state_offsets_escape_first_scan_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let base = Instant::now();
|
||||
let state_len = 65_536usize;
|
||||
let scan_limit = 1_024usize;
|
||||
|
||||
let mut saw_offset_outside_first_window = false;
|
||||
for i in 0..8_192u64 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
((i >> 16) & 0xff) as u8,
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
((i.wrapping_mul(131)) & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_nanos(i);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
if start >= scan_limit {
|
||||
saw_offset_outside_first_window = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_offset_outside_first_window,
|
||||
"scan start offset must cover the full auth-probe state, not only the first scan window"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_large_state_offsets_cover_many_scan_windows() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let base = Instant::now();
|
||||
let state_len = 65_536usize;
|
||||
let scan_limit = 1_024usize;
|
||||
|
||||
let mut covered_windows = HashSet::new();
|
||||
for i in 0..16_384u64 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
((i >> 16) & 0xff) as u8,
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
((i.wrapping_mul(17)) & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_micros(i);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
covered_windows.insert(start / scan_limit);
|
||||
}
|
||||
|
||||
assert!(
|
||||
covered_windows.len() >= 16,
|
||||
"eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}",
|
||||
covered_windows.len(),
|
||||
state_len / scan_limit
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_offset_always_stays_inside_state_len() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let mut seed = 0xC0FF_EE12_3456_789Au64;
|
||||
let base = Instant::now();
|
||||
|
||||
for _ in 0..8_192usize {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
(seed >> 24) as u8,
|
||||
(seed >> 16) as u8,
|
||||
(seed >> 8) as u8,
|
||||
seed as u8,
|
||||
));
|
||||
let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
|
||||
let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0x0fff);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
|
||||
assert!(start < state_len, "scan offset must stay inside state length");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn edge_zero_state_len_yields_zero_start_offset() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
|
||||
let now = Instant::now();
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_scan_start_offset(ip, now, 0, 16),
|
||||
0,
|
||||
"empty map must not produce non-zero scan offset"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let base = Instant::now();
|
||||
let scan_limit = 16usize;
|
||||
let state_len = 65_536usize;
|
||||
|
||||
let mut saw_offset_outside_window = false;
|
||||
for i in 0..2048u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
203,
|
||||
((i >> 16) & 0xff) as u8,
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_micros(i as u64);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
assert!(
|
||||
start < state_len,
|
||||
"start offset must stay within state length; start={start}, len={state_len}"
|
||||
);
|
||||
if start >= scan_limit {
|
||||
saw_offset_outside_window = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_offset_outside_window,
|
||||
"large-state eviction must sample beyond the first scan window"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
|
||||
let now = Instant::now();
|
||||
|
||||
for state_len in 1..32usize {
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
|
||||
assert!(
|
||||
start < state_len,
|
||||
"start offset must never exceed state length when scan limit is larger"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let mut seed = 0x5A41_5356_4C32_3236u64;
|
||||
let base = Instant::now();
|
||||
|
||||
for _ in 0..4096 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
(seed >> 24) as u8,
|
||||
(seed >> 16) as u8,
|
||||
(seed >> 8) as u8,
|
||||
seed as u8,
|
||||
));
|
||||
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
|
||||
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0xffff);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
|
||||
assert!(
|
||||
start < state_len,
|
||||
"scan offset must stay inside state length"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
|
||||
let base = Instant::now();
|
||||
let mut uniq = HashSet::new();
|
||||
|
||||
for i in 0..512u64 {
|
||||
let now = base + Duration::from_nanos(i);
|
||||
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
|
||||
uniq.insert(offset);
|
||||
}
|
||||
|
||||
assert!(
|
||||
uniq.len() >= 256,
|
||||
"offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
|
||||
uniq.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let now = Instant::now();
|
||||
let mut uniq = HashSet::new();
|
||||
|
||||
for i in 0..1024u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
(i >> 16) as u8,
|
||||
(i >> 8) as u8,
|
||||
i as u8,
|
||||
(255 - (i as u8)),
|
||||
));
|
||||
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
|
||||
}
|
||||
|
||||
assert!(
|
||||
uniq.len() >= 512,
|
||||
"scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
|
||||
uniq.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
|
||||
let start = Instant::now();
|
||||
let mut workers = Vec::new();
|
||||
for worker in 0..8u8 {
|
||||
workers.push(tokio::spawn(async move {
|
||||
for i in 0..8192u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
10,
|
||||
worker,
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
));
|
||||
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
worker.await.expect("saturation worker must not panic");
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"state must remain hard-capped under parallel saturation churn"
|
||||
);
|
||||
|
||||
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
|
||||
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let mut seed = 0xA55A_1357_2468_9BDFu64;
|
||||
let base = Instant::now();
|
||||
|
||||
for _ in 0..8192 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
(seed >> 24) as u8,
|
||||
(seed >> 16) as u8,
|
||||
(seed >> 8) as u8,
|
||||
seed as u8,
|
||||
));
|
||||
let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
|
||||
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0x1fff);
|
||||
|
||||
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
assert!(
|
||||
offset < state_len,
|
||||
"scan offset must always remain inside state length"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
use super::*;
|
||||
|
||||
fn handshake_source() -> &'static str {
|
||||
include_str!("../handshake.rs")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
|
||||
let src = handshake_source();
|
||||
assert!(
|
||||
src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
|
||||
"candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
|
||||
let src = handshake_source();
|
||||
assert!(
|
||||
src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
|
||||
"candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_aes_ctr_initialization_uses_zeroizing_references() {
|
||||
let src = handshake_source();
|
||||
assert!(
|
||||
src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
|
||||
&& src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
|
||||
"AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn security_success_struct_copies_out_of_zeroizing_wrappers() {
|
||||
let src = handshake_source();
|
||||
assert!(
|
||||
src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
|
||||
"HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
use super::*;
|
||||
use crate::crypto::{sha256, sha256_hmac, AesCtr};
|
||||
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
|
||||
use rand::{RngExt, SeedableRng};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::rngs::StdRng;
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
|
|
|
|||
|
|
@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
|
|||
];
|
||||
|
||||
let mut meaningful_improvement_seen = false;
|
||||
let mut baseline_sum = 0.0f64;
|
||||
let mut hardened_sum = 0.0f64;
|
||||
let mut pair_count = 0usize;
|
||||
let mut informative_baseline_sum = 0.0f64;
|
||||
let mut informative_hardened_sum = 0.0f64;
|
||||
let mut informative_pair_count = 0usize;
|
||||
let mut low_info_baseline_sum = 0.0f64;
|
||||
let mut low_info_hardened_sum = 0.0f64;
|
||||
let mut low_info_pair_count = 0usize;
|
||||
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
|
||||
let tolerated_pair_regression = acc_quant_step + 0.03;
|
||||
|
||||
|
|
@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
|
|||
hardened_acc <= baseline_acc + tolerated_pair_regression,
|
||||
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
|
||||
);
|
||||
informative_baseline_sum += baseline_acc;
|
||||
informative_hardened_sum += hardened_acc;
|
||||
informative_pair_count += 1;
|
||||
} else {
|
||||
// Low-information pairs (near-random baseline separability) are expected
|
||||
// to exhibit quantized jitter at low sample counts; do not fold them into
|
||||
// strict average-regression checks used for informative side-channel signal.
|
||||
low_info_baseline_sum += baseline_acc;
|
||||
low_info_hardened_sum += hardened_acc;
|
||||
low_info_pair_count += 1;
|
||||
}
|
||||
|
||||
println!(
|
||||
|
|
@ -532,19 +545,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
|
|||
meaningful_improvement_seen = true;
|
||||
}
|
||||
|
||||
baseline_sum += baseline_acc;
|
||||
hardened_sum += hardened_acc;
|
||||
pair_count += 1;
|
||||
}
|
||||
|
||||
let baseline_avg = baseline_sum / pair_count as f64;
|
||||
let hardened_avg = hardened_sum / pair_count as f64;
|
||||
assert!(
|
||||
informative_pair_count > 0,
|
||||
"expected at least one informative pair for timing-separability guard"
|
||||
);
|
||||
|
||||
let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
|
||||
let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
|
||||
|
||||
assert!(
|
||||
hardened_avg <= baseline_avg + 0.10,
|
||||
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
|
||||
informative_hardened_avg <= informative_baseline_avg + 0.10,
|
||||
"normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
|
||||
);
|
||||
|
||||
if low_info_pair_count > 0 {
|
||||
let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
|
||||
let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
|
||||
assert!(
|
||||
low_info_hardened_avg <= low_info_baseline_avg + 0.40,
|
||||
"normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
|
||||
);
|
||||
}
|
||||
|
||||
// Optional signal only: do not require improvement on every run because
|
||||
// noisy CI schedulers can flatten pairwise differences at low sample counts.
|
||||
let _ = meaningful_improvement_seen;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
struct EndlessReader {
|
||||
produced: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl AsyncRead for EndlessReader {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let len = buf.remaining().max(1);
|
||||
let fill = vec![0xAA; len];
|
||||
buf.put_slice(&fill);
|
||||
self.produced.fetch_add(len, Ordering::Relaxed);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loop_guard_unspecified_bind_uses_interface_inventory() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
|
||||
let interfaces = vec!["192.168.44.10".parse().unwrap()];
|
||||
|
||||
assert!(is_mask_target_local_listener_with_interfaces(
|
||||
"mask.example",
|
||||
443,
|
||||
local,
|
||||
Some(resolved),
|
||||
&interfaces,
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn consume_client_data_stops_after_byte_cap_without_eof() {
|
||||
let produced = Arc::new(AtomicUsize::new(0));
|
||||
let reader = EndlessReader {
|
||||
produced: Arc::clone(&produced),
|
||||
};
|
||||
let cap = 10_000usize;
|
||||
|
||||
consume_client_data(reader, cap).await;
|
||||
|
||||
let total = produced.load(Ordering::Relaxed);
|
||||
assert!(
|
||||
total >= cap,
|
||||
"consume path must read at least up to cap before stopping"
|
||||
);
|
||||
assert!(
|
||||
total <= cap + 8192,
|
||||
"consume path must stop within one read chunk above cap"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = true;
|
||||
config.general.beobachten_minutes = 0;
|
||||
|
||||
let ttl = masking_beobachten_ttl(&config);
|
||||
assert_eq!(ttl, std::time::Duration::from_secs(60));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 0;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 0;
|
||||
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
assert_eq!(budget, MASK_TIMEOUT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let accept_task = tokio::spawn(async move {
|
||||
timeout(Duration::from_millis(120), listener.accept())
|
||||
.await
|
||||
.is_ok()
|
||||
});
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = backend_addr.port();
|
||||
config.censorship.mask_proxy_protocol = 2;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
|
||||
let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
handle_bad_client(
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
let accepted = accept_task.await.unwrap();
|
||||
assert!(
|
||||
!accepted,
|
||||
"loop guard must fail closed before any recursive PROXY protocol amplification"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_recognizes_extended_http_probe_verbs() {
|
||||
assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
|
||||
assert_eq!(detect_client_type(b"CO"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"CON"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"TR"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"PAT"), "HTTP");
|
||||
}
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
use super::*;
|
||||
use crate::network::dns_overrides::install_entries;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
async fn run_connect_failure_case(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timing_normalization_enabled: bool,
|
||||
peer: SocketAddr,
|
||||
) -> Duration {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some(host.to_string());
|
||||
config.censorship.mask_port = port;
|
||||
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 120;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 120;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
|
||||
|
||||
let (mut client_writer, client_reader) = duplex(1024);
|
||||
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
|
||||
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
client_visible_writer,
|
||||
probe,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_writer.shutdown().await.unwrap();
|
||||
|
||||
timeout(Duration::from_secs(4), task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(n, 0, "connect-failure path must close client-visible writer");
|
||||
|
||||
started.elapsed()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_failure_refusal_close_behavior_matrix() {
|
||||
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||
drop(temp_listener);
|
||||
|
||||
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
|
||||
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
|
||||
.parse()
|
||||
.unwrap();
|
||||
let elapsed = run_connect_failure_case(
|
||||
"127.0.0.1",
|
||||
unused_port,
|
||||
timing_normalization_enabled,
|
||||
peer,
|
||||
)
|
||||
.await;
|
||||
|
||||
if timing_normalization_enabled {
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
|
||||
"normalized refusal path must honor configured timing envelope without stalling"
|
||||
);
|
||||
} else {
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
|
||||
"non-normalized refusal path must honor baseline connect budget without stalling"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_failure_overridden_hostname_close_behavior_matrix() {
|
||||
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||
drop(temp_listener);
|
||||
|
||||
// Make hostname resolution deterministic in tests so timing ceilings are meaningful.
|
||||
install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
|
||||
|
||||
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
|
||||
let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
|
||||
.parse()
|
||||
.unwrap();
|
||||
let elapsed = run_connect_failure_case(
|
||||
"mask.invalid",
|
||||
unused_port,
|
||||
timing_normalization_enabled,
|
||||
peer,
|
||||
)
|
||||
.await;
|
||||
|
||||
if timing_normalization_enabled {
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
|
||||
"normalized overridden-host path must honor configured timing envelope without stalling"
|
||||
);
|
||||
} else {
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
|
||||
"non-normalized overridden-host path must honor baseline connect budget without stalling"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
install_entries(&[]).unwrap();
|
||||
}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
struct OneByteThenStall {
|
||||
sent: bool,
|
||||
}
|
||||
|
||||
impl AsyncRead for OneByteThenStall {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if !self.sent {
|
||||
self.sent = true;
|
||||
buf.put_slice(&[0x42]);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stalling_client_terminates_at_idle_not_relay_timeout() {
|
||||
let reader = OneByteThenStall { sent: false };
|
||||
let started = Instant::now();
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
MASK_RELAY_TIMEOUT,
|
||||
consume_client_data(reader, MASK_BUFFER_SIZE * 4),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"consume_client_data should complete by per-read idle timeout, not hit relay timeout"
|
||||
);
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
assert!(
|
||||
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
|
||||
"consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
|
||||
);
|
||||
assert!(
|
||||
elapsed < MASK_RELAY_TIMEOUT,
|
||||
"consume_client_data waited full relay timeout ({elapsed:?}); \
|
||||
per-read idle timeout is missing"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fast_reader_drains_to_eof() {
|
||||
let data = vec![0xAAu8; 32 * 1024];
|
||||
let reader = std::io::Cursor::new(data);
|
||||
|
||||
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
|
||||
.await
|
||||
.expect("consume_client_data did not complete for fast EOF reader");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn io_error_terminates_cleanly() {
|
||||
struct ErrReader;
|
||||
|
||||
impl AsyncRead for ErrReader {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionReset,
|
||||
"simulated reset",
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX))
|
||||
.await
|
||||
.expect("consume_client_data did not return on I/O error");
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
struct OneByteThenStall {
|
||||
sent: bool,
|
||||
}
|
||||
|
||||
impl AsyncRead for OneByteThenStall {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if !self.sent {
|
||||
self.sent = true;
|
||||
buf.put_slice(&[0xAA]);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn consume_stall_stress_finishes_within_idle_budget() {
|
||||
let mut set = JoinSet::new();
|
||||
let started = Instant::now();
|
||||
|
||||
for _ in 0..64 {
|
||||
set.spawn(async {
|
||||
tokio::time::timeout(
|
||||
MASK_RELAY_TIMEOUT,
|
||||
consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
|
||||
)
|
||||
.await
|
||||
.expect("consume_client_data exceeded relay timeout under stall load");
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(res) = set.join_next().await {
|
||||
res.unwrap();
|
||||
}
|
||||
|
||||
// Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
|
||||
// for 100ms should complete well under a strict 600ms boundary.
|
||||
assert!(
|
||||
started.elapsed() < MASK_RELAY_TIMEOUT * 3,
|
||||
"stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn consume_zero_cap_returns_immediately() {
|
||||
let started = Instant::now();
|
||||
consume_client_data(tokio::io::empty(), 0).await;
|
||||
assert!(
|
||||
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
|
||||
"zero byte cap must return immediately"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
use super::*;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
fn make_self_target_config(
|
||||
timing_normalization_enabled: bool,
|
||||
floor_ms: u64,
|
||||
ceiling_ms: u64,
|
||||
beobachten_enabled: bool,
|
||||
) -> ProxyConfig {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = beobachten_enabled;
|
||||
config.general.beobachten_minutes = 5;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = 443;
|
||||
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
|
||||
config.censorship.mask_timing_normalization_floor_ms = floor_ms;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
|
||||
config
|
||||
}
|
||||
|
||||
async fn run_self_target_refusal(
|
||||
config: ProxyConfig,
|
||||
peer: SocketAddr,
|
||||
initial: &'static [u8],
|
||||
) -> Duration {
|
||||
let beobachten = BeobachtenStore::new();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
|
||||
|
||||
let (mut client, server) = duplex(1024);
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten)
|
||||
.await;
|
||||
});
|
||||
|
||||
client
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("client shutdown must succeed");
|
||||
|
||||
timeout(Duration::from_secs(3), task)
|
||||
.await
|
||||
.expect("self-target refusal must complete in bounded time")
|
||||
.expect("self-target refusal task must not panic");
|
||||
|
||||
started.elapsed()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_self_target_refusal_honors_normalization_floor() {
|
||||
let config = make_self_target_config(true, 120, 120, false);
|
||||
let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
|
||||
|
||||
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
|
||||
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260),
|
||||
"normalized self-target refusal must stay within expected envelope"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() {
|
||||
let config = make_self_target_config(false, 240, 240, false);
|
||||
let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer");
|
||||
|
||||
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
|
||||
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(180),
|
||||
"non-normalized path must not inherit normalization floor delays"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_ceiling_below_floor_uses_floor_fail_closed() {
|
||||
let config = make_self_target_config(true, 140, 80, false);
|
||||
let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer");
|
||||
|
||||
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
|
||||
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280),
|
||||
"ceiling<floor must clamp to floor to preserve deterministic normalization"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_blackhat_parallel_probes_remain_bounded_and_uniform() {
|
||||
let workers = 24usize;
|
||||
let mut tasks = Vec::with_capacity(workers);
|
||||
|
||||
for idx in 0..workers {
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let cfg = make_self_target_config(true, 110, 140, false);
|
||||
let peer: SocketAddr = format!("203.0.113.50:{}", 54100 + idx as u16)
|
||||
.parse()
|
||||
.expect("valid peer");
|
||||
run_self_target_refusal(cfg, peer, b"GET /x HTTP/1.1\r\n\r\n").await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut min = Duration::from_secs(60);
|
||||
let mut max = Duration::from_millis(0);
|
||||
for task in tasks {
|
||||
let elapsed = task.await.expect("probe task must not panic");
|
||||
if elapsed < min {
|
||||
min = elapsed;
|
||||
}
|
||||
if elapsed > max {
|
||||
max = elapsed;
|
||||
}
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320),
|
||||
"parallel probe latency must stay bounded under normalization"
|
||||
);
|
||||
}
|
||||
|
||||
assert!(
|
||||
max.saturating_sub(min) <= Duration::from_millis(130),
|
||||
"normalization should limit path variance across adversarial parallel probes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_beobachten_records_probe_classification_on_refusal() {
|
||||
let config = make_self_target_config(false, 0, 0, true);
|
||||
let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer");
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let (mut client, server) = duplex(1024);
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
beobachten.snapshot_text(Duration::from_secs(60))
|
||||
});
|
||||
|
||||
client
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("client shutdown must succeed");
|
||||
|
||||
let snapshot = timeout(Duration::from_secs(3), task)
|
||||
.await
|
||||
.expect("integration task must complete")
|
||||
.expect("integration task must not panic");
|
||||
|
||||
assert!(snapshot.contains("[HTTP]"));
|
||||
assert!(snapshot.contains("198.51.100.71-1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_timing_configuration_matrix_is_bounded() {
|
||||
let mut seed = 0xA17E_55AA_2026_0323u64;
|
||||
|
||||
for case in 0..48u64 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let enabled = (seed & 1) == 0;
|
||||
let floor = (seed >> 8) % 180;
|
||||
let ceiling = (seed >> 24) % 180;
|
||||
let config = make_self_target_config(enabled, floor, ceiling, false);
|
||||
let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16))
|
||||
.parse()
|
||||
.expect("valid peer");
|
||||
|
||||
let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await;
|
||||
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(420),
|
||||
"fuzz case must stay bounded and never hang"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() {
|
||||
let workers = 64usize;
|
||||
let mut tasks = Vec::with_capacity(workers);
|
||||
|
||||
for idx in 0..workers {
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let config = make_self_target_config(false, 0, 0, false);
|
||||
let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16)
|
||||
.parse()
|
||||
.expect("valid peer");
|
||||
run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await
|
||||
}));
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(5), async {
|
||||
for task in tasks {
|
||||
let elapsed = task.await.expect("stress task must not panic");
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(260),
|
||||
"stress refusal must remain bounded without normalization"
|
||||
);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("high-fanout refusal workload must complete without deadlock");
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
use super::*;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::Duration;
|
||||
|
||||
#[tokio::test]
|
||||
async fn http2_preface_is_forwarded_and_recorded_as_http() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let preface = preface.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut received = vec![0u8; preface.len()];
|
||||
stream.read_exact(&mut received).await.unwrap();
|
||||
assert_eq!(received, preface);
|
||||
}
|
||||
});
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = true;
|
||||
config.general.beobachten_minutes = 1;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = backend_addr.port();
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let (client_reader, _client_writer) = tokio::io::duplex(512);
|
||||
let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512);
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
client_visible_writer,
|
||||
&preface,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(2), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||
assert!(snapshot.contains("[HTTP]"));
|
||||
assert!(snapshot.contains("198.51.100.130-1"));
|
||||
}
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn full_http2_preface_classified_as_http_probe() {
|
||||
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
|
||||
assert!(
|
||||
is_http_probe(preface),
|
||||
"HTTP/2 connection preface must be classified as HTTP probe"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_http2_preface_3_bytes_classified() {
|
||||
assert!(
|
||||
is_http_probe(b"PRI"),
|
||||
"3-byte HTTP/2 preface prefix must be classified"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_http2_preface_2_bytes_classified() {
|
||||
assert!(
|
||||
is_http_probe(b"PR"),
|
||||
"2-byte HTTP/2 preface prefix must be classified"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn existing_http1_methods_unaffected() {
|
||||
for prefix in [
|
||||
b"GET / HTTP/1.1\r\n".as_ref(),
|
||||
b"POST /api HTTP/1.1\r\n".as_ref(),
|
||||
b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(),
|
||||
b"TRACE / HTTP/1.1\r\n".as_ref(),
|
||||
b"PATCH / HTTP/1.1\r\n".as_ref(),
|
||||
] {
|
||||
assert!(is_http_probe(prefix));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_http_data_not_classified() {
|
||||
for data in [
|
||||
b"\x16\x03\x01\x00\xf1".as_ref(),
|
||||
b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(),
|
||||
b"\x00\x01\x02\x03".as_ref(),
|
||||
b"".as_ref(),
|
||||
b"P".as_ref(),
|
||||
] {
|
||||
assert!(!is_http_probe(data));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_non_http_prefixes_not_misclassified() {
|
||||
// Deterministic pseudo-fuzz to exercise classifier edges while avoiding
|
||||
// known HTTP method and partial windows.
|
||||
let mut x = 0x1234_5678u32;
|
||||
for _ in 0..1024 {
|
||||
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
|
||||
let len = 4 + ((x >> 8) as usize % 12);
|
||||
let mut data = vec![0u8; len];
|
||||
for byte in &mut data {
|
||||
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
|
||||
*byte = (x & 0xFF) as u8;
|
||||
}
|
||||
|
||||
if [
|
||||
b"GET ".as_ref(),
|
||||
b"POST".as_ref(),
|
||||
b"HEAD".as_ref(),
|
||||
b"PUT ".as_ref(),
|
||||
b"DELETE".as_ref(),
|
||||
b"OPTIONS".as_ref(),
|
||||
b"CONNECT".as_ref(),
|
||||
b"TRACE".as_ref(),
|
||||
b"PATCH".as_ref(),
|
||||
b"PRI ".as_ref(),
|
||||
]
|
||||
.iter()
|
||||
.any(|m| data.starts_with(m))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
assert!(
|
||||
!is_http_probe(&data),
|
||||
"non-http pseudo-fuzz input misclassified: {:?}",
|
||||
&data[..data.len().min(8)]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn exact_four_byte_http_tokens_are_classified() {
|
||||
for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] {
|
||||
assert!(
|
||||
is_http_probe(token),
|
||||
"exact 4-byte token must be classified as HTTP probe: {:?}",
|
||||
token
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exact_four_byte_non_http_tokens_are_not_classified() {
|
||||
for token in [
|
||||
b"GEX ".as_ref(),
|
||||
b"POXT".as_ref(),
|
||||
b"HEA/".as_ref(),
|
||||
b"PU\0 ".as_ref(),
|
||||
b"PRI/".as_ref(),
|
||||
] {
|
||||
assert!(
|
||||
!is_http_probe(token),
|
||||
"non-HTTP 4-byte token must not be classified: {:?}",
|
||||
token
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() {
|
||||
assert_eq!(detect_client_type(b"GET "), "HTTP");
|
||||
assert_eq!(detect_client_type(b"PRI "), "HTTP");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exact_long_http_tokens_are_classified() {
|
||||
for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] {
|
||||
assert!(
|
||||
is_http_probe(token),
|
||||
"exact long HTTP token must be classified as HTTP probe: {:?}",
|
||||
token
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() {
|
||||
assert_eq!(detect_client_type(b"CONNECT"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"TRACE"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"PATCH"), "HTTP");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_four_byte_ascii_noise_not_misclassified() {
|
||||
// Deterministic pseudo-fuzz over 4-byte printable ASCII inputs.
|
||||
let mut x = 0xA17C_93E5u32;
|
||||
for _ in 0..2048 {
|
||||
let mut token = [0u8; 4];
|
||||
for byte in &mut token {
|
||||
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
|
||||
*byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
|
||||
}
|
||||
|
||||
if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
|
||||
.iter()
|
||||
.any(|m| token.as_slice() == *m)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
assert!(
|
||||
!is_http_probe(&token),
|
||||
"pseudo-fuzz noise misclassified as HTTP probe: {:?}",
|
||||
token
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#![cfg(unix)]
|
||||
|
||||
use super::*;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
fn interface_cache_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
|
||||
let _guard = interface_cache_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
reset_local_interface_enumerations_for_tests();
|
||||
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
let workers = 32usize;
|
||||
let barrier = std::sync::Arc::new(Barrier::new(workers));
|
||||
let mut tasks = Vec::with_capacity(workers);
|
||||
|
||||
for _ in 0..workers {
|
||||
let barrier = std::sync::Arc::clone(&barrier);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
barrier.wait().await;
|
||||
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let _ = task.await.expect("parallel cache task must not panic");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
local_interface_enumerations_for_tests(),
|
||||
1,
|
||||
"parallel cold misses must coalesce into a single interface enumeration"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
#![cfg(unix)]
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
|
||||
let previous = vec![
|
||||
"192.168.100.7"
|
||||
.parse::<IpAddr>()
|
||||
.expect("must parse interface ip"),
|
||||
];
|
||||
let refreshed = Vec::new();
|
||||
|
||||
let next = choose_interface_snapshot(&previous, refreshed);
|
||||
|
||||
assert_eq!(
|
||||
next, previous,
|
||||
"empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
|
||||
let previous = vec![
|
||||
"192.168.100.7"
|
||||
.parse::<IpAddr>()
|
||||
.expect("must parse interface ip"),
|
||||
];
|
||||
let refreshed = vec![
|
||||
"10.55.0.3"
|
||||
.parse::<IpAddr>()
|
||||
.expect("must parse refreshed interface ip"),
|
||||
];
|
||||
|
||||
let next = choose_interface_snapshot(&previous, refreshed.clone());
|
||||
|
||||
assert_eq!(next, refreshed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
|
||||
let previous = Vec::new();
|
||||
let refreshed = Vec::new();
|
||||
|
||||
let next = choose_interface_snapshot(&previous, refreshed);
|
||||
|
||||
assert!(
|
||||
next.is_empty(),
|
||||
"empty refresh with no previous snapshot should remain empty"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
#![cfg(unix)]
|
||||
|
||||
use super::*;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn interface_cache_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
|
||||
let _guard = interface_cache_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
reset_local_interface_enumerations_for_tests();
|
||||
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
||||
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
|
||||
|
||||
assert_eq!(
|
||||
local_interface_enumerations_for_tests(),
|
||||
1,
|
||||
"interface enumeration must be cached across repeated bad-client checks"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
|
||||
let _guard = interface_cache_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
reset_local_interface_enumerations_for_tests();
|
||||
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
|
||||
|
||||
assert!(!is_local, "different port must not be treated as local listener");
|
||||
assert_eq!(
|
||||
local_interface_enumerations_for_tests(),
|
||||
0,
|
||||
"port mismatch should bypass interface enumeration entirely"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
use super::*;
|
||||
use std::net::{SocketAddr, TcpListener as StdTcpListener};
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
fn closed_local_port() -> u16 {
|
||||
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
drop(listener);
|
||||
port
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"]
|
||||
async fn redteam_offline_target_should_drop_idle_client_early() {
|
||||
let (client_read, mut client_write) = duplex(1024);
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = closed_local_port();
|
||||
cfg.censorship.mask_timing_normalization_enabled = false;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let handler = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_read,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer_addr,
|
||||
local_addr,
|
||||
&cfg,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(150)).await;
|
||||
let write_res = client_write.write_all(b"probe-should-be-closed").await;
|
||||
assert!(
|
||||
write_res.is_err(),
|
||||
"offline target path still keeps client writable before consume timeout"
|
||||
);
|
||||
|
||||
handler.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"]
|
||||
async fn redteam_offline_target_should_not_sleep_to_mask_refusal() {
|
||||
let (client_read, mut client_write) = duplex(1024);
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = closed_local_port();
|
||||
cfg.censorship.mask_timing_normalization_enabled = false;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let started = Instant::now();
|
||||
let handler = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_read,
|
||||
tokio::io::sink(),
|
||||
b"\x16\x03\x01\x00\x05hello",
|
||||
peer_addr,
|
||||
local_addr,
|
||||
&cfg,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_write.shutdown().await.unwrap();
|
||||
let _ = handler.await;
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(10),
|
||||
"offline target path still applies coarse masking sleep and is fingerprintable"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"]
|
||||
async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() {
|
||||
let mut samples = Vec::new();
|
||||
|
||||
for i in 0..12u16 {
|
||||
let (client_read, mut client_write) = duplex(1024);
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = closed_local_port();
|
||||
cfg.censorship.mask_timing_normalization_enabled = false;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let started = Instant::now();
|
||||
let handler = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_read,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer_addr,
|
||||
local_addr,
|
||||
&cfg,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_write.shutdown().await.unwrap();
|
||||
let _ = handler.await;
|
||||
samples.push(started.elapsed());
|
||||
}
|
||||
|
||||
let min = samples.iter().copied().min().unwrap_or_default();
|
||||
let max = samples.iter().copied().max().unwrap_or_default();
|
||||
let spread = max.saturating_sub(min);
|
||||
|
||||
assert!(
|
||||
spread <= Duration::from_millis(5),
|
||||
"offline refusal timing spread too wide for strict red-team envelope: {:?}",
|
||||
spread
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "manual red-team: host resolver failure should complete without panic"]
|
||||
async fn redteam_dns_resolution_failure_must_not_panic() {
|
||||
let (client_read, mut client_write) = duplex(1024);
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
|
||||
cfg.censorship.mask_port = 443;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let handler = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_read,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer_addr,
|
||||
local_addr,
|
||||
&cfg,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_write.shutdown().await.unwrap();
|
||||
let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"dns failure path stalled or panicked instead of terminating"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Instant;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
struct NeverWritable;
|
||||
|
||||
impl AsyncWrite for NeverWritable {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
|
||||
let mut writer = NeverWritable;
|
||||
let started = Instant::now();
|
||||
|
||||
maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await;
|
||||
|
||||
assert!(
|
||||
started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
|
||||
"shape padding blocked past timeout budget"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
let mut writer = tokio::io::BufWriter::new(&mut output);
|
||||
maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
|
||||
assert!(output.is_empty());
|
||||
}
|
||||
|
|
@ -0,0 +1,289 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
struct FinitePatternReader {
|
||||
remaining: usize,
|
||||
chunk: usize,
|
||||
read_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl FinitePatternReader {
|
||||
fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
remaining: total,
|
||||
chunk,
|
||||
read_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for FinitePatternReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.read_calls.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let take = self.remaining.min(self.chunk).min(buf.remaining());
|
||||
if take == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let fill = vec![0x5Au8; take];
|
||||
buf.put_slice(&fill);
|
||||
self.remaining -= take;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingWriter {
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl AsyncWrite for CountingWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.written = self.written.saturating_add(buf.len());
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct NeverReadyReader;
|
||||
|
||||
impl AsyncRead for NeverReadyReader {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
struct BudgetProbeReader {
|
||||
remaining: usize,
|
||||
total_read: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl BudgetProbeReader {
|
||||
fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
remaining: total,
|
||||
total_read,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for BudgetProbeReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let take = self.remaining.min(buf.remaining());
|
||||
if take == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let fill = vec![0xA5u8; take];
|
||||
buf.put_slice(&fill);
|
||||
self.remaining -= take;
|
||||
self.total_read.fetch_add(take, Ordering::Relaxed);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
|
||||
let mut writer = CountingWriter::default();
|
||||
|
||||
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
|
||||
|
||||
assert_eq!(
|
||||
outcome.total, PROD_CAP_BYTES,
|
||||
"copy path must stop at explicit production cap"
|
||||
);
|
||||
assert_eq!(writer.written, PROD_CAP_BYTES);
|
||||
assert!(
|
||||
!outcome.ended_by_eof,
|
||||
"byte-cap stop must not be misclassified as EOF"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_consume_with_zero_cap_performs_no_reads() {
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
|
||||
|
||||
consume_client_data_with_timeout_and_cap(reader, 0).await;
|
||||
|
||||
assert_eq!(
|
||||
read_calls.load(Ordering::Relaxed),
|
||||
0,
|
||||
"zero cap must return before reading attacker-controlled bytes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_copy_below_cap_reports_eof_without_overread() {
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let payload = 73 * 1024;
|
||||
let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
|
||||
let mut writer = CountingWriter::default();
|
||||
|
||||
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
|
||||
|
||||
assert_eq!(outcome.total, payload);
|
||||
assert_eq!(writer.written, payload);
|
||||
assert!(
|
||||
outcome.ended_by_eof,
|
||||
"finite upstream below cap must terminate via EOF path"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
|
||||
let started = Instant::now();
|
||||
|
||||
consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;
|
||||
|
||||
assert!(
|
||||
started.elapsed() < Duration::from_millis(350),
|
||||
"never-ready reader must be bounded by idle/relay timeout protections"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_consume_path_honors_production_cap_for_large_payload() {
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);
|
||||
|
||||
let bounded = timeout(
|
||||
Duration::from_millis(350),
|
||||
consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
bounded.is_ok(),
|
||||
"consume path with production cap must finish within bounded time"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() {
|
||||
let byte_cap = 5usize;
|
||||
let total_read = Arc::new(AtomicUsize::new(0));
|
||||
let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read));
|
||||
|
||||
consume_client_data_with_timeout_and_cap(reader, byte_cap).await;
|
||||
|
||||
assert!(
|
||||
total_read.load(Ordering::Relaxed) <= byte_cap,
|
||||
"consume path must not read more than configured byte cap"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() {
|
||||
let mut seed = 0x1234_5678_9ABC_DEF0u64;
|
||||
|
||||
for _case in 0..96u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let cap = ((seed & 0x3ffff) as usize).saturating_add(1);
|
||||
let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1);
|
||||
let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1);
|
||||
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let mut reader = FinitePatternReader::new(payload, chunk, read_calls);
|
||||
let mut writer = CountingWriter::default();
|
||||
|
||||
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await;
|
||||
let expected = payload.min(cap);
|
||||
|
||||
assert_eq!(
|
||||
outcome.total, expected,
|
||||
"copy total must match min(payload, cap) under fuzzed inputs"
|
||||
);
|
||||
assert_eq!(writer.written, expected);
|
||||
if payload <= cap {
|
||||
assert!(outcome.ended_by_eof);
|
||||
} else {
|
||||
assert!(!outcome.ended_by_eof);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() {
|
||||
let workers = 8usize;
|
||||
let mut tasks = Vec::with_capacity(workers);
|
||||
|
||||
for idx in 0..workers {
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||
let mut reader = FinitePatternReader::new(
|
||||
PROD_CAP_BYTES + (idx + 1) * 4096,
|
||||
4096 + (idx * 257),
|
||||
read_calls,
|
||||
);
|
||||
let mut writer = CountingWriter::default();
|
||||
copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await
|
||||
}));
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(3), async {
|
||||
for task in tasks {
|
||||
let outcome = task.await.expect("stress task must not panic");
|
||||
assert_eq!(
|
||||
outcome.total, PROD_CAP_BYTES,
|
||||
"stress copy task must stay within production cap"
|
||||
);
|
||||
assert!(
|
||||
!outcome.ended_by_eof,
|
||||
"stress task should end due to cap, not EOF"
|
||||
);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("stress suite must complete in bounded time");
|
||||
}
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_to_mask_enforces_masking_session_byte_cap() {
|
||||
let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01];
|
||||
let extra = vec![0xAB; 96 * 1024];
|
||||
|
||||
let (client_reader, mut client_writer) = duplex(128 * 1024);
|
||||
let (mask_read, _mask_read_peer) = duplex(1024);
|
||||
let (mut mask_observer, mask_write) = duplex(256 * 1024);
|
||||
let initial_for_task = initial.clone();
|
||||
|
||||
let relay = tokio::spawn(async move {
|
||||
relay_to_mask(
|
||||
client_reader,
|
||||
sink(),
|
||||
mask_read,
|
||||
mask_write,
|
||||
&initial_for_task,
|
||||
false,
|
||||
512,
|
||||
4096,
|
||||
false,
|
||||
0,
|
||||
false,
|
||||
32 * 1024,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_writer.write_all(&extra).await.unwrap();
|
||||
client_writer.shutdown().await.unwrap();
|
||||
|
||||
timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut observed = Vec::new();
|
||||
timeout(
|
||||
Duration::from_secs(2),
|
||||
mask_observer.read_to_end(&mut observed),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
// In this deterministic test, relay must stop exactly at the configured cap.
|
||||
assert_eq!(
|
||||
observed.len(),
|
||||
initial.len() + (32 * 1024),
|
||||
"masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
|
||||
observed.len(),
|
||||
initial.len(),
|
||||
32 * 1024
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
|
||||
let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
|
||||
|
||||
let (client_reader, mut client_writer) = duplex(8 * 1024);
|
||||
let (mask_read, _mask_read_peer) = duplex(8 * 1024);
|
||||
let (mut mask_observer, mask_write) = duplex(8 * 1024);
|
||||
let initial_for_task = initial.clone();
|
||||
|
||||
let relay = tokio::spawn(async move {
|
||||
relay_to_mask(
|
||||
client_reader,
|
||||
sink(),
|
||||
mask_read,
|
||||
mask_write,
|
||||
&initial_for_task,
|
||||
false,
|
||||
512,
|
||||
4096,
|
||||
false,
|
||||
0,
|
||||
false,
|
||||
32 * 1024,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client_writer.shutdown().await.unwrap();
|
||||
|
||||
let mut observed = Vec::new();
|
||||
timeout(
|
||||
Duration::from_millis(80),
|
||||
mask_observer.read_to_end(&mut observed),
|
||||
)
|
||||
.await
|
||||
.expect("mask backend write side should be half-closed promptly")
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(&observed[..initial.len()], initial.as_slice());
|
||||
|
||||
timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
use super::*;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
async fn collect_padding(
|
||||
total_sent: usize,
|
||||
enabled: bool,
|
||||
floor: usize,
|
||||
cap: usize,
|
||||
above_cap_blur: bool,
|
||||
blur_max: usize,
|
||||
aggressive: bool,
|
||||
) -> Vec<u8> {
|
||||
let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
|
||||
|
||||
maybe_write_shape_padding(
|
||||
&mut tx,
|
||||
total_sent,
|
||||
enabled,
|
||||
floor,
|
||||
cap,
|
||||
above_cap_blur,
|
||||
blur_max,
|
||||
aggressive,
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(tx);
|
||||
|
||||
let mut output = Vec::new();
|
||||
timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
|
||||
.await
|
||||
.expect("reading padded output timed out")
|
||||
.expect("failed reading padded output");
|
||||
output
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn padding_output_is_not_all_zero() {
|
||||
let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
|
||||
|
||||
assert!(
|
||||
output.len() >= 255,
|
||||
"expected at least 255 padding bytes, got {}",
|
||||
output.len()
|
||||
);
|
||||
|
||||
let nonzero = output.iter().filter(|&&b| b != 0).count();
|
||||
// In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
|
||||
// A weak nonzero check can miss severe entropy collapse.
|
||||
assert!(
|
||||
nonzero >= 240,
|
||||
"RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}",
|
||||
nonzero,
|
||||
output.len(),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn padding_reaches_first_bucket_boundary() {
|
||||
let output = collect_padding(1, true, 64, 4096, false, 0, false).await;
|
||||
assert_eq!(output.len(), 63);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn disabled_padding_produces_no_output() {
|
||||
let output = collect_padding(0, false, 256, 4096, false, 0, false).await;
|
||||
assert!(output.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn at_cap_without_blur_produces_no_output() {
|
||||
let output = collect_padding(4096, true, 64, 4096, false, 0, false).await;
|
||||
assert!(output.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() {
|
||||
let output = collect_padding(4096, true, 64, 4096, true, 128, true).await;
|
||||
assert!(!output.is_empty());
|
||||
assert!(output.len() <= 128, "blur exceeded max: {}", output.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stress_padding_runs_are_not_constant_pattern() {
|
||||
// Stress and sanity-check: repeated runs should not collapse to identical
|
||||
// first 16 bytes across all samples.
|
||||
let mut first_chunks = Vec::new();
|
||||
for _ in 0..64 {
|
||||
let out = collect_padding(1, true, 64, 4096, false, 0, false).await;
|
||||
first_chunks.push(out[..16].to_vec());
|
||||
}
|
||||
|
||||
let first = &first_chunks[0];
|
||||
let all_same = first_chunks.iter().all(|chunk| chunk == first);
|
||||
assert!(
|
||||
!all_same,
|
||||
"all stress samples had identical prefix, rng output appears degenerate"
|
||||
);
|
||||
}
|
||||
|
|
@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
|
|||
false,
|
||||
0,
|
||||
false,
|
||||
5 * 1024 * 1024,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
|
@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
|||
false,
|
||||
0,
|
||||
false,
|
||||
5 * 1024 * 1024,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,360 @@
|
|||
use super::*;
|
||||
use std::net::TcpListener as StdTcpListener;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
fn closed_local_port() -> u16 {
|
||||
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
drop(listener);
|
||||
port
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_matches_literal_ipv4_listener() {
|
||||
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"198.51.100.40",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_matches_bracketed_ipv6_listener() {
|
||||
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"[2001:db8::44]",
|
||||
8443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
|
||||
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async(
|
||||
"203.0.113.44",
|
||||
8443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
|
||||
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"::ffff:127.0.0.1",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"127.0.0.1",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async(
|
||||
"mask.example",
|
||||
443,
|
||||
local,
|
||||
Some(remote),
|
||||
)
|
||||
.await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_fallback_refuses_recursive_loopback_connect() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
let accept_task = tokio::spawn(async move {
|
||||
timeout(Duration::from_millis(120), listener.accept())
|
||||
.await
|
||||
.is_ok()
|
||||
});
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some(local_addr.ip().to_string());
|
||||
config.censorship.mask_port = local_addr.port();
|
||||
config.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
handle_bad_client(
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
b"GET /",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
let accepted = accept_task.await.unwrap();
|
||||
assert!(
|
||||
!accepted,
|
||||
"self-target masking must fail closed without connecting to local listener"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_ip_different_port_still_forwards_to_mask_backend() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let probe = b"GET /".to_vec();
|
||||
let accept_task = tokio::spawn({
|
||||
let expected = probe.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = vec![0u8; expected.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, expected);
|
||||
}
|
||||
});
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = backend_addr.port();
|
||||
config.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
handle_bad_client(
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
&probe,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
timeout(Duration::from_secs(2), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_http_boundary_get_and_post() {
|
||||
assert_eq!(detect_client_type(b"GET "), "HTTP");
|
||||
assert_eq!(detect_client_type(b"GET /"), "HTTP");
|
||||
|
||||
assert_eq!(detect_client_type(b"POST"), "HTTP");
|
||||
assert_eq!(detect_client_type(b"POST "), "HTTP");
|
||||
assert_eq!(detect_client_type(b"POSTX"), "HTTP");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detect_client_type_tls_and_length_boundaries() {
|
||||
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner");
|
||||
assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner");
|
||||
|
||||
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
|
||||
assert_eq!(detect_client_type(b"1234567890"), "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() {
|
||||
let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap();
|
||||
let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
|
||||
let header = build_mask_proxy_header(1, peer, local).unwrap();
|
||||
assert_eq!(header, b"PROXY UNKNOWN\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() {
|
||||
let floor = usize::MAX / 2 + 1;
|
||||
let cap = usize::MAX;
|
||||
let total = floor + 1;
|
||||
assert_eq!(next_mask_shape_bucket(total, floor, cap), total);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_reject_path_keeps_timing_budget() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = 443;
|
||||
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let (client, server) = duplex(1024);
|
||||
drop(client);
|
||||
|
||||
let started = Instant::now();
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250),
|
||||
"self-target reject path must keep coarse timing budget without stalling"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_path_idle_timeout_eviction_remains_effective() {
|
||||
let (client_read, mut client_write) = duplex(1024);
|
||||
let (mask_read, mask_write) = duplex(1024);
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
client_write.write_all(b"a").await.unwrap();
|
||||
tokio::time::sleep(Duration::from_millis(180)).await;
|
||||
let _ = client_write.write_all(b"b").await;
|
||||
});
|
||||
|
||||
let started = Instant::now();
|
||||
relay_to_mask(
|
||||
client_read,
|
||||
tokio::io::sink(),
|
||||
mask_read,
|
||||
mask_write,
|
||||
b"init",
|
||||
false,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
0,
|
||||
false,
|
||||
5 * 1024 * 1024,
|
||||
)
|
||||
.await;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180),
|
||||
"idle-timeout eviction must occur before late trickle write"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = closed_local_port();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 120;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 120;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let (mut client, server) = duplex(1024);
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client.shutdown().await.unwrap();
|
||||
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220),
|
||||
"offline-refusal path must honor normalization budget without unbounded drift"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = closed_local_port();
|
||||
config.censorship.mask_timing_normalization_enabled = false;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let (mut client, server) = duplex(1024);
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(120)).await;
|
||||
client
|
||||
.write_all(b"still-open-before-timeout")
|
||||
.await
|
||||
.expect("connection should still be open before consume timeout expires");
|
||||
|
||||
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350),
|
||||
"offline-refusal path must not retain idle client indefinitely"
|
||||
);
|
||||
}
|
||||
|
|
@ -43,6 +43,7 @@ async fn run_relay_case(
|
|||
above_cap_blur,
|
||||
above_cap_blur_max_bytes,
|
||||
false,
|
||||
5 * 1024 * 1024,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
|
|||
false,
|
||||
0,
|
||||
false,
|
||||
5 * 1024 * 1024,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
#![cfg(unix)]
|
||||
|
||||
use super::*;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = 443;
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 120;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 120;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer");
|
||||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
|
||||
let held_refresh_guard = refresh_lock.lock().await;
|
||||
|
||||
let (mut client, server) = duplex(1024);
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(80)).await;
|
||||
drop(held_refresh_guard);
|
||||
client.shutdown().await.expect("client shutdown must succeed");
|
||||
|
||||
timeout(Duration::from_secs(2), task)
|
||||
.await
|
||||
.expect("task must finish in bounded time")
|
||||
.expect("task must not panic");
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350),
|
||||
"timing normalization floor must start after pre-outcome self-target checks"
|
||||
);
|
||||
}
|
||||
|
|
@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout};
|
|||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let _pressure_guard = super::relay_idle_pressure_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
|
|
|
|||
|
|
@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() {
|
|||
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_helper_matches_capped_generic_helper_matrix() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-parity";
|
||||
|
||||
for used in [0u64, 1, 7, 63, 127, 255] {
|
||||
stats.sub_user_octets_to(user, stats.get_user_total_octets(user));
|
||||
stats.add_user_octets_to(user, used);
|
||||
|
||||
for quota in [8u64, 64, 128, 256] {
|
||||
for overshoot in [0u64, 1, 5, 32] {
|
||||
for bytes in [0u64, 1, 2, 7, 31, 64] {
|
||||
let soft = quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota),
|
||||
bytes,
|
||||
overshoot,
|
||||
);
|
||||
let capped = quota_would_be_exceeded_for_user(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota_soft_cap(quota, overshoot)),
|
||||
bytes,
|
||||
);
|
||||
assert_eq!(
|
||||
soft, capped,
|
||||
"soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_helper_none_limit_never_rejects() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-none";
|
||||
stats.add_user_octets_to(user, u64::MAX);
|
||||
|
||||
assert!(!quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
None,
|
||||
u64::MAX,
|
||||
u64::MAX,
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_cap_saturates_and_stays_fail_closed() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-saturating";
|
||||
let quota = u64::MAX - 2;
|
||||
let overshoot = 100;
|
||||
|
||||
assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX);
|
||||
|
||||
stats.add_user_octets_to(user, u64::MAX - 1);
|
||||
assert!(quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota),
|
||||
2,
|
||||
overshoot,
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,295 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BlockingWriteState {
|
||||
write_entered: AtomicBool,
|
||||
released: AtomicBool,
|
||||
write_waker: Mutex<Option<Waker>>,
|
||||
write_entered_notify: Notify,
|
||||
}
|
||||
|
||||
struct BlockingWrite {
|
||||
state: Arc<BlockingWriteState>,
|
||||
}
|
||||
|
||||
impl BlockingWrite {
|
||||
fn new(state: Arc<BlockingWriteState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BlockingWrite {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.state.write_entered.store(true, Ordering::Release);
|
||||
self.state.write_entered_notify.notify_waiters();
|
||||
|
||||
if self.state.released.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut slot) = self.state.write_waker.lock() {
|
||||
*slot = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
|
||||
for _ in 0..8 {
|
||||
if state.write_entered.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
|
||||
}
|
||||
|
||||
panic!("blocking writer did not enter poll_write in bounded time");
|
||||
}
|
||||
|
||||
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
|
||||
state.released.store(true, Ordering::Release);
|
||||
if let Ok(mut slot) = state.write_waker.lock()
|
||||
&& let Some(waker) = slot.take()
|
||||
{
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-release-regression-{}", std::process::id());
|
||||
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let first = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(4),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_000,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
|
||||
.await
|
||||
.expect("cross-mode lock must be released while first write is pending");
|
||||
drop(guard);
|
||||
|
||||
let second = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
timeout(
|
||||
Duration::from_millis(150),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(4),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_001,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
let second_result = second
|
||||
.await
|
||||
.expect("second task must not panic")
|
||||
.expect("second write must not block on cross-mode lock");
|
||||
assert!(
|
||||
matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })),
|
||||
"second write must fail closed due to first write reservation"
|
||||
);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let first_result = timeout(Duration::from_millis(300), first)
|
||||
.await
|
||||
.expect("first task timed out")
|
||||
.expect("first task must not panic");
|
||||
assert!(first_result.is_ok());
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-release-stress-{}", std::process::id());
|
||||
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let first = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x01, 0x02]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(3),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_100,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..48u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
timeout(
|
||||
Duration::from_millis(200),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x10]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(3),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_200 + idx,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut ok = 0usize;
|
||||
let mut quota_exceeded = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
let timed = done.expect("waiter task must not panic");
|
||||
let result = timed.expect("waiter must not block behind pending first write");
|
||||
match result {
|
||||
Ok(_) => ok += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1,
|
||||
Err(other) => panic!("unexpected error in waiter: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota");
|
||||
assert_eq!(quota_exceeded, 47);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let first_result = timeout(Duration::from_millis(300), first)
|
||||
.await
|
||||
.expect("first task timed out")
|
||||
.expect("first task must not panic");
|
||||
assert!(first_result.is_ok());
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 3);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
|
||||
}
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_counter_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-mode-lookup-{}", std::process::id());
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..8u64 {
|
||||
let outcome = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAB]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
&bytes_me2c,
|
||||
20_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"prefetched lock path must not re-query lock registry per frame"
|
||||
);
|
||||
assert_eq!(stats.get_user_total_octets(&user), 8);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn control_without_prefetched_lock_still_uses_registry_lookup_path() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-mode-lookup-control-{}", std::process::id());
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xCD]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
None,
|
||||
&bytes_me2c,
|
||||
20_100,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
1,
|
||||
"fallback path without prefetched lock should perform a registry lookup"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,376 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-positive-{}", std::process::id());
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(128),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-negative-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before ME->C call");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(256),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-edge-none-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock while quota is disabled");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = timeout(
|
||||
Duration::from_millis(80),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x11, 0x22]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
None,
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_003,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("quota-none path must not wait on cross-mode lock");
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-matrix-adversarial-{}", std::process::id());
|
||||
let limit = 64u64;
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for idx in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(limit),
|
||||
0,
|
||||
bytes_me2c.as_ref(),
|
||||
11_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut ok = 0usize;
|
||||
for task in tasks {
|
||||
match task.await.expect("task must not panic") {
|
||||
Ok(_) => ok += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(ok, limit as usize);
|
||||
assert_eq!(stats.get_user_total_octets(&user), limit);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() {
|
||||
let user = format!("middle-cross-matrix-integration-{}", std::process::id());
|
||||
let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let middle_lock = cross_mode_quota_user_lock_for_tests(&user);
|
||||
assert!(
|
||||
Arc::ptr_eq(&relay_lock, &middle_lock),
|
||||
"relay and middle-relay must share the same cross-mode lock identity"
|
||||
);
|
||||
|
||||
let held_guard = relay_lock
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock");
|
||||
|
||||
let stats = Stats::new();
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let middle_blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x92]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
12_001,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
assert!(middle_blocked.is_err());
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let middle_ready = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x94]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
12_002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("middle path must complete after release");
|
||||
|
||||
assert!(middle_ready.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-fuzz-{}", std::process::id());
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let mut seed = 0xC0DE_1234_55AA_9988u64;
|
||||
|
||||
for case in 0..96u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold = (seed & 0x03) == 0;
|
||||
let mut held_lock = None;
|
||||
let maybe_guard = if hold {
|
||||
held_lock = Some(cross_mode_quota_user_lock_for_tests(&user));
|
||||
Some(
|
||||
held_lock
|
||||
.as_ref()
|
||||
.expect("held lock should be present")
|
||||
.try_lock()
|
||||
.expect("cross-mode lock should be acquirable in fuzz round"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let payload_len = ((seed >> 8) as usize % 8) + 1;
|
||||
let payload = vec![(seed & 0xff) as u8; payload_len];
|
||||
let before = stats.get_user_total_octets(&user);
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let timed = timeout(
|
||||
Duration::from_millis(20),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
13_000 + case as u64,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if hold {
|
||||
assert!(timed.is_err(), "held-lock fuzz round must block within timeout");
|
||||
assert_eq!(stats.get_user_total_octets(&user), before);
|
||||
} else {
|
||||
let done = timed.expect("unheld fuzz round must complete in time");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
drop(held_lock);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() {
|
||||
let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id());
|
||||
let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id());
|
||||
|
||||
let held = cross_mode_quota_user_lock_for_tests(&held_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock for blocked user");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for idx in 0..64u64 {
|
||||
let user = free_user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let stats = Stats::new();
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA0]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
14_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let done = task.await.expect("free-user task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("free-user tasks should complete without waiting for held user's lock");
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BlockingWriteState {
|
||||
write_entered: AtomicBool,
|
||||
released: AtomicBool,
|
||||
write_waker: Mutex<Option<Waker>>,
|
||||
write_entered_notify: Notify,
|
||||
}
|
||||
|
||||
struct BlockingWrite {
|
||||
state: Arc<BlockingWriteState>,
|
||||
}
|
||||
|
||||
impl BlockingWrite {
|
||||
fn new(state: Arc<BlockingWriteState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BlockingWrite {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.state.write_entered.store(true, Ordering::Release);
|
||||
self.state.write_entered_notify.notify_waiters();
|
||||
|
||||
if self.state.released.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut slot) = self.state.write_waker.lock() {
|
||||
*slot = Some(cx.waker().clone());
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
|
||||
for _ in 0..8 {
|
||||
if state.write_entered.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
|
||||
}
|
||||
|
||||
panic!("blocking writer did not enter poll_write in bounded time");
|
||||
}
|
||||
|
||||
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
|
||||
state.released.store(true, Ordering::Release);
|
||||
if let Ok(mut slot) = state.write_waker.lock()
|
||||
&& let Some(waker) = slot.take()
|
||||
{
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-me2c-cross-mode-held-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before ME->C write path");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9901,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
blocked.is_err(),
|
||||
"ME->C quota reservation path must be serialized by held shared cross-mode lock"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let released = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x42]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9902,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("ME->C write must complete after cross-mode lock release");
|
||||
|
||||
assert!(released.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-me2c-cross-mode-free-{}", std::process::id());
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x55, 0x66]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9903,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("uncontended ME->C path should not stall");
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 2);
|
||||
assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id());
|
||||
let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let worker = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
let rng = SecureRandom::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
9910,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
|
||||
.await
|
||||
.expect("cross-mode lock must be free while ME->C write is pending");
|
||||
drop(acquired_guard);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let result = timeout(Duration::from_millis(300), worker)
|
||||
.await
|
||||
.expect("ME->C worker timed out after releasing blocking writer")
|
||||
.expect("ME->C worker must not panic");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
|
@ -0,0 +1,232 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GateState {
|
||||
open: AtomicBool,
|
||||
parked_waker: std::sync::Mutex<Option<Waker>>,
|
||||
}
|
||||
|
||||
impl GateState {
|
||||
fn open(&self) {
|
||||
self.open.store(true, Ordering::Relaxed);
|
||||
if let Ok(mut guard) = self.parked_waker.lock()
|
||||
&& let Some(w) = guard.take()
|
||||
{
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn has_waiter(&self) -> bool {
|
||||
self.parked_waker
|
||||
.lock()
|
||||
.map(|guard| guard.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GateWriter {
|
||||
gate: Arc<GateState>,
|
||||
}
|
||||
|
||||
impl GateWriter {
|
||||
fn new(gate: Arc<GateState>) -> Self {
|
||||
Self { gate }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for GateWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if self.gate.open.load(Ordering::Relaxed) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut guard) = self.gate.parked_waker.lock() {
|
||||
*guard = Some(cx.waker().clone());
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct FailingWriter;
|
||||
|
||||
impl AsyncWrite for FailingWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"injected writer failure",
|
||||
)))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() {
|
||||
let stats = Stats::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let rng = SecureRandom::new();
|
||||
let quota_limit = Some(1024);
|
||||
let user = "hol-quota-user";
|
||||
|
||||
let gate = Arc::new(GateState::default());
|
||||
|
||||
let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate)));
|
||||
let slow_task = tokio::spawn(async move {
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]),
|
||||
},
|
||||
&mut blocked_writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
quota_limit,
|
||||
0,
|
||||
&bytes_me2c,
|
||||
7001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
timeout(Duration::from_millis(100), async {
|
||||
loop {
|
||||
if gate.has_waiter() {
|
||||
break;
|
||||
}
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("first writer must reach backpressure and park");
|
||||
|
||||
let stats_fast = Stats::new();
|
||||
let bytes_fast = AtomicU64::new(0);
|
||||
let rng_fast = SecureRandom::new();
|
||||
let mut fast_writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf_fast = Vec::new();
|
||||
|
||||
timeout(
|
||||
Duration::from_millis(50),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut fast_writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng_fast,
|
||||
&mut frame_buf_fast,
|
||||
&stats_fast,
|
||||
user,
|
||||
quota_limit,
|
||||
0,
|
||||
&bytes_fast,
|
||||
7002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("peer connection must not be blocked by same-user stalled write")
|
||||
.expect("fast peer write must succeed");
|
||||
|
||||
gate.open();
|
||||
let slow_result = timeout(Duration::from_secs(1), slow_task)
|
||||
.await
|
||||
.expect("stalled task must complete once gate opens")
|
||||
.expect("stalled task must not panic");
|
||||
assert!(slow_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() {
|
||||
let stats = Stats::new();
|
||||
let user = "rollback-user";
|
||||
stats.add_user_octets_from(user, 7);
|
||||
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let rng = SecureRandom::new();
|
||||
let mut writer = make_crypto_writer(FailingWriter);
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
Some(64),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
7003,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::Io(_))));
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
7,
|
||||
"failed client write must not overcharge user quota accounting"
|
||||
);
|
||||
assert_eq!(
|
||||
bytes_me2c.load(Ordering::Relaxed),
|
||||
0,
|
||||
"failed client write must not inflate ME->C forensic byte counter"
|
||||
);
|
||||
}
|
||||
|
|
@ -3,7 +3,7 @@ use crate::crypto::AesCtr;
|
|||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::duplex;
|
||||
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
|
||||
|
|
@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
|
|||
}
|
||||
}
|
||||
|
||||
fn idle_pressure_test_lock() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
match idle_pressure_test_lock().lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(poisoned) => poisoned.into_inner(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
|
||||
let (reader, _writer) = duplex(1024);
|
||||
|
|
@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
|
|||
|
||||
#[test]
|
||||
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
|
|||
|
||||
#[test]
|
||||
fn pressure_does_not_evict_without_new_pressure_signal() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
|
|||
|
||||
#[test]
|
||||
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated(
|
|||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let stats = Stats::new();
|
||||
|
||||
|
|
@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
|
|||
|
||||
#[test]
|
||||
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
{
|
||||
|
|
@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting(
|
|||
|
||||
#[test]
|
||||
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
{
|
||||
|
|
@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
|
|||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
|
||||
{
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
|
|||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
|
||||
let _guard = acquire_idle_pressure_test_lock();
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
use super::*;
|
||||
use std::panic::{AssertUnwindSafe, catch_unwind};
|
||||
|
||||
#[test]
|
||||
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
let _ = catch_unwind(AssertUnwindSafe(|| {
|
||||
let registry = relay_idle_candidate_registry();
|
||||
let mut guard = registry
|
||||
.lock()
|
||||
.expect("registry lock must be acquired before poison");
|
||||
guard.by_conn_id.insert(
|
||||
999,
|
||||
RelayIdleCandidateMeta {
|
||||
mark_order_seq: 1,
|
||||
mark_pressure_seq: 0,
|
||||
},
|
||||
);
|
||||
guard.ordered.insert((1, 999));
|
||||
panic!("intentional poison for idle-registry recovery");
|
||||
}));
|
||||
|
||||
// Helper lock must recover from poison, reset stale state, and continue.
|
||||
assert!(mark_relay_idle_candidate(42));
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(42));
|
||||
|
||||
let before = relay_pressure_event_seq();
|
||||
note_relay_pressure_event();
|
||||
let after = relay_pressure_event_seq();
|
||||
assert!(after > before, "pressure accounting must still advance after poison");
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
let _ = catch_unwind(AssertUnwindSafe(|| {
|
||||
let registry = relay_idle_candidate_registry();
|
||||
let _guard = registry
|
||||
.lock()
|
||||
.expect("registry lock must be acquired before poison");
|
||||
panic!("intentional poison while lock held");
|
||||
}));
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
|
||||
assert_eq!(oldest_relay_idle_candidate(), None);
|
||||
assert_eq!(relay_pressure_event_seq(), 0);
|
||||
|
||||
assert!(mark_relay_idle_candidate(7));
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(7));
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
}
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, OnceLock, Mutex};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_me2c_quota_counts_bytes_exactly_once() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-positive-{}", std::process::id());
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4, 5]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(64),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 5);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_crossmode_lock_blocks_me2c_write() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-negative-{}", std::process::id());
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let _held = lock.try_lock().expect("lock must be held");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xFE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(16),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_101,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_zero_quota_zero_payload_is_fail_closed() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-edge-{}", std::process::id());
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::new(),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(0),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_201,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-middle-ext-blackhat-{}", std::process::id());
|
||||
let quota = 64u64;
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for i in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize];
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
0,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
70_301 + i,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut succeeded = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) => succeeded += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed));
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(succeeded <= quota as usize);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_shared_prefetched_lock_blocks_then_releases_writer() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-integration-{}", std::process::id());
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let held = lock
|
||||
.try_lock()
|
||||
.expect("integration test must hold prefetched lock first");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA1]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(8),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_360,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
assert!(blocked.is_err());
|
||||
|
||||
drop(held);
|
||||
|
||||
let after_release = timeout(
|
||||
Duration::from_millis(150),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA2]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(8),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_361,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("writer should progress once the shared lock is released");
|
||||
|
||||
assert!(after_release.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-fuzz-{}", std::process::id());
|
||||
let mut seed = 0xCAFE_BABE_1234u64;
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for case in 0..48u32 {
|
||||
seed ^= seed << 5;
|
||||
seed ^= seed >> 12;
|
||||
seed ^= seed << 13;
|
||||
let hold = (seed & 0x1) == 0;
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let maybe_guard = if hold {
|
||||
Some(lock.try_lock().unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_millis(30),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(128),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_401 + case as u64,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if hold {
|
||||
assert!(result.is_err());
|
||||
} else {
|
||||
assert!(result.unwrap().is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let held = Arc::new(AsyncMutex::new(()));
|
||||
let _held_guard = held.try_lock().unwrap();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for i in 0..48u64 {
|
||||
set.spawn(async move {
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-stress-free-{i}");
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let free_lock = Arc::new(AsyncMutex::new(()));
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1),
|
||||
0,
|
||||
Some(&free_lock),
|
||||
&bytes_me2c,
|
||||
70_500 + i,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
while let Some(task) = set.join_next().await {
|
||||
task.unwrap().unwrap();
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,399 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_counter_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-positive-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..12u64 {
|
||||
let payload = vec![0x5A; ((idx % 4) + 1) as usize];
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(512),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"prefetched lock path must avoid hot-path registry lookups"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
bytes_me2c.load(Ordering::Relaxed),
|
||||
"forensics and quota accounting must remain synchronized"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-negative-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold lock before calling ME->C writer");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(64),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_100,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_zero_quota_and_zero_payload_is_fail_closed() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-edge-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::new(),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(0),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_200,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-extreme-blackhat-{}", std::process::id());
|
||||
let quota = 80u64;
|
||||
let overshoot = 7u64;
|
||||
let soft_limit = quota + overshoot;
|
||||
let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let len = ((idx % 5) + 1) as usize;
|
||||
let payload = vec![0xAA; len];
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
overshoot,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
31_300 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
let total = stats.get_user_total_octets(&user);
|
||||
assert!(
|
||||
total <= soft_limit,
|
||||
"parallel adversarial race must stay under soft cap"
|
||||
);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_without_prefetched_lock_uses_registry_lookup_path() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-integration-{}", std::process::id());
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..3u64 {
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(16),
|
||||
0,
|
||||
None,
|
||||
&bytes_me2c,
|
||||
31_400 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
3,
|
||||
"control path should perform one lock-registry lookup per call"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-fuzz-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let mut seed = 0xA11C_55EE_2026_0323u64;
|
||||
|
||||
for idx in 0..512u64 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let quota = 24 + (seed & 0x3f);
|
||||
let overshoot = (seed >> 13) & 0x0f;
|
||||
let len = ((seed >> 19) & 0x07) + 1;
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let before = stats.get_user_total_octets(&user);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(vec![0x11; len as usize]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(quota),
|
||||
overshoot,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_500 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
let after = stats.get_user_total_octets(&user);
|
||||
if result.is_ok() {
|
||||
assert!(after >= before);
|
||||
} else {
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(after, before);
|
||||
}
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), after);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-extreme-stress-{}", std::process::id());
|
||||
let quota = 96u64;
|
||||
let lock: Arc<AsyncMutex<()>> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..384u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xFF]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
0,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
31_600 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut success = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) => success += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(success, quota as usize);
|
||||
assert_eq!(stats.get_user_total_octets(&user), quota);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota);
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"stress prefetched path must not use lock registry lookups"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,361 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration as TokioDuration, sleep};
|
||||
|
||||
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
|
||||
where
|
||||
T: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||
}
|
||||
|
||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.encrypt(plaintext)
|
||||
}
|
||||
|
||||
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
|
||||
RelayForensicsState {
|
||||
trace_id: 0xB200_0000 + conn_id,
|
||||
conn_id,
|
||||
user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
|
||||
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
|
||||
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
|
||||
started_at,
|
||||
bytes_c2me: 0,
|
||||
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||
desync_all_full: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
|
||||
RelayClientIdlePolicy {
|
||||
enabled: true,
|
||||
soft_idle: Duration::from_millis(50),
|
||||
hard_idle: Duration::from_millis(120),
|
||||
grace_after_downstream_activity: Duration::from_secs(0),
|
||||
legacy_frame_read_timeout: Duration::from_millis(50),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_once(
|
||||
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
|
||||
proto: ProtoTag,
|
||||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
idle_state: &mut RelayClientIdleState,
|
||||
) -> Result<Option<(PooledBuffer, bool)>> {
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
read_client_payload_with_idle_policy(
|
||||
crypto_reader,
|
||||
proto,
|
||||
1024,
|
||||
&buffer_pool,
|
||||
forensics,
|
||||
frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
forensics.started_at,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for idx in 0..32u64 {
|
||||
set.spawn(async move {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(1000 + idx, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let flood_plaintext = vec![0u8; 1024];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer.write_all(&flood_encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = run_relay_test_step_timeout(
|
||||
"tiny flood task",
|
||||
read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::Proxy(_))));
|
||||
assert_eq!(frame_counter, 0);
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(result) = set.join_next().await {
|
||||
result.expect("parallel tiny flood worker must not panic");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for idx in 0..24u64 {
|
||||
set.spawn(async move {
|
||||
let (reader, mut writer) = duplex(2048);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(2000 + idx, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let payload = [idx as u8, 2, 3, 4];
|
||||
let mut plaintext = Vec::with_capacity(20);
|
||||
for _ in 0..6 {
|
||||
plaintext.push(0x00);
|
||||
}
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer.write_all(&encrypted).await.unwrap();
|
||||
|
||||
let result = run_relay_test_step_timeout(
|
||||
"benign tiny burst read",
|
||||
read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("benign payload must parse")
|
||||
.expect("benign payload must return frame");
|
||||
|
||||
assert_eq!(result.0.as_ref(), &payload);
|
||||
assert_eq!(frame_counter, 1);
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(result) = set.join_next().await {
|
||||
result.expect("parallel benign worker must not panic");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for idx in 0..12u64 {
|
||||
set.spawn(async move {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(3000 + idx, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(2000);
|
||||
for n in 0..180u8 {
|
||||
plaintext.push(0x00);
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
for chunk in encrypted.chunks(17) {
|
||||
writer.write_all(chunk).await.unwrap();
|
||||
sleep(TokioDuration::from_millis(1)).await;
|
||||
}
|
||||
drop(writer);
|
||||
});
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..220 {
|
||||
let result = run_relay_test_step_timeout(
|
||||
"alternating jitter read step",
|
||||
read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some((_payload, _))) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected error in alternating jitter case: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("writer jitter task must not panic");
|
||||
assert!(closed, "alternating attack must close before EOF");
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(result) = set.join_next().await {
|
||||
result.expect("alternating jitter worker must not panic");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_mixed_population_attackers_close_benign_survive() {
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for idx in 0..20u64 {
|
||||
set.spawn(async move {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(4000 + idx, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
if idx % 2 == 0 {
|
||||
let mut plaintext = Vec::with_capacity(1280);
|
||||
for n in 0..140u8 {
|
||||
plaintext.push(0x00);
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[n, n, n, n]);
|
||||
}
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..200 {
|
||||
match read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected attacker error: {other}"),
|
||||
}
|
||||
}
|
||||
assert!(closed, "attacker session must fail closed");
|
||||
} else {
|
||||
let payload = [1u8, 9, 8, 7];
|
||||
let mut plaintext = Vec::new();
|
||||
for _ in 0..4 {
|
||||
plaintext.push(0x00);
|
||||
}
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
|
||||
let got = read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
)
|
||||
.await
|
||||
.expect("benign session must parse")
|
||||
.expect("benign session must return a frame");
|
||||
assert_eq!(got.0.as_ref(), &payload);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(result) = set.join_next().await {
|
||||
result.expect("mixed-population worker must not panic");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
for case in 0..40u64 {
|
||||
set.spawn(async move {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(5000 + case, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut seed = 0x9E37_79B9u64 ^ (case << 8);
|
||||
let mut plaintext = Vec::with_capacity(2048);
|
||||
for _ in 0..256 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let is_tiny = (seed & 1) == 0;
|
||||
if is_tiny {
|
||||
plaintext.push(0x00);
|
||||
} else {
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
|
||||
}
|
||||
}
|
||||
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
for _ in 0..320 {
|
||||
let step = run_relay_test_step_timeout(
|
||||
"fuzz case read step",
|
||||
read_once(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match step {
|
||||
Ok(Some(_)) => {}
|
||||
Err(ProxyError::Proxy(_)) => break,
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected fuzz case error: {other}"),
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(result) = set.join_next().await {
|
||||
result.expect("fuzz worker must not panic");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,425 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration as TokioDuration, sleep};
|
||||
|
||||
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
|
||||
where
|
||||
T: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||
}
|
||||
|
||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.encrypt(plaintext)
|
||||
}
|
||||
|
||||
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
|
||||
RelayForensicsState {
|
||||
trace_id: 0xB300_0000 + conn_id,
|
||||
conn_id,
|
||||
user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
|
||||
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
|
||||
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
|
||||
started_at,
|
||||
bytes_c2me: 0,
|
||||
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||
desync_all_full: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
|
||||
RelayClientIdlePolicy {
|
||||
enabled: true,
|
||||
soft_idle: Duration::from_millis(50),
|
||||
hard_idle: Duration::from_millis(120),
|
||||
grace_after_downstream_activity: Duration::from_secs(0),
|
||||
legacy_frame_read_timeout: Duration::from_millis(50),
|
||||
}
|
||||
}
|
||||
|
||||
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
|
||||
match proto {
|
||||
ProtoTag::Abridged => plaintext.push(0x00),
|
||||
ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()),
|
||||
}
|
||||
}
|
||||
|
||||
fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
|
||||
match proto {
|
||||
ProtoTag::Abridged => {
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
}
|
||||
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||
plaintext.extend_from_slice(&4u32.to_le_bytes());
|
||||
plaintext.extend_from_slice(&payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_chunked_with_jitter(
|
||||
writer: &mut tokio::io::DuplexStream,
|
||||
bytes: &[u8],
|
||||
mut seed: u64,
|
||||
) {
|
||||
let mut offset = 0usize;
|
||||
while offset < bytes.len() {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let chunk_len = 1 + ((seed as usize) & 0x1f);
|
||||
let end = (offset + chunk_len).min(bytes.len());
|
||||
writer.write_all(&bytes[offset..end]).await.unwrap();
|
||||
|
||||
let delay_ms = ((seed >> 16) % 3) as u64;
|
||||
if delay_ms > 0 {
|
||||
sleep(TokioDuration::from_millis(delay_ms)).await;
|
||||
}
|
||||
offset = end;
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_once_with_state(
|
||||
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
|
||||
proto: ProtoTag,
|
||||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
idle_state: &mut RelayClientIdleState,
|
||||
) -> Result<Option<(PooledBuffer, bool)>> {
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
read_client_payload_with_idle_policy(
|
||||
crypto_reader,
|
||||
proto,
|
||||
1024,
|
||||
&buffer_pool,
|
||||
forensics,
|
||||
frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
forensics.started_at,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
|
||||
matches!(result, Err(ProxyError::Proxy(_)))
|
||||
|| matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_chunked_zero_flood_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6101, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(4 * 256);
|
||||
for _ in 0..256 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
|
||||
drop(writer);
|
||||
|
||||
let result = run_relay_test_step_timeout(
|
||||
"intermediate flood read",
|
||||
read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
is_fail_closed_outcome(&result),
|
||||
"zero-length flood must fail closed via debt guard or idle timeout"
|
||||
);
|
||||
assert_eq!(frame_counter, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_chunked_zero_flood_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6102, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(4 * 256);
|
||||
for _ in 0..256 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
|
||||
drop(writer);
|
||||
|
||||
let result = run_relay_test_step_timeout(
|
||||
"secure flood read",
|
||||
read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Secure,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
is_fail_closed_outcome(&result),
|
||||
"secure zero-length flood must fail closed via debt guard or idle timeout"
|
||||
);
|
||||
assert_eq!(frame_counter, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_chunked_alternating_attack_closes_before_eof() {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6103, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(8 * 200);
|
||||
for n in 0..180u8 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
|
||||
append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
|
||||
drop(writer);
|
||||
});
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..240 {
|
||||
let step = run_relay_test_step_timeout(
|
||||
"intermediate alternating read step",
|
||||
read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match step {
|
||||
Ok(Some(_)) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected intermediate alternating error: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("intermediate writer task must not panic");
|
||||
assert!(closed, "intermediate alternating attack must fail closed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_chunked_alternating_attack_closes_before_eof() {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6104, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(8 * 200);
|
||||
for n in 0..180u8 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
|
||||
append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await;
|
||||
drop(writer);
|
||||
});
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..240 {
|
||||
let step = run_relay_test_step_timeout(
|
||||
"secure alternating read step",
|
||||
read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Secure,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match step {
|
||||
Ok(Some(_)) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected secure alternating error: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("secure writer task must not panic");
|
||||
assert!(closed, "secure alternating attack must fail closed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6105, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let payload = [9u8, 8, 7, 6];
|
||||
let mut plaintext = Vec::new();
|
||||
for _ in 0..7 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
|
||||
}
|
||||
append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload);
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await;
|
||||
|
||||
let result = read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
)
|
||||
.await
|
||||
.expect("intermediate safe burst should parse")
|
||||
.expect("intermediate safe burst should return a frame");
|
||||
|
||||
assert_eq!(result.0.as_ref(), &payload);
|
||||
assert_eq!(frame_counter, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_chunked_safe_small_burst_still_returns_real_frame() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6106, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let payload = [3u8, 1, 4, 1];
|
||||
let mut plaintext = Vec::new();
|
||||
for _ in 0..7 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
|
||||
}
|
||||
append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
|
||||
|
||||
let result = read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Secure,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
)
|
||||
.await
|
||||
.expect("secure safe burst should parse")
|
||||
.expect("secure safe burst should return a frame");
|
||||
|
||||
assert_eq!(result.0.as_ref(), &payload);
|
||||
assert_eq!(frame_counter, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
|
||||
let mut seed = 0xDEAD_BEEF_2026_0322u64;
|
||||
|
||||
for case in 0..48u64 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let proto = if (seed & 1) == 0 {
|
||||
ProtoTag::Intermediate
|
||||
} else {
|
||||
ProtoTag::Secure
|
||||
};
|
||||
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let started = Instant::now();
|
||||
let forensics = make_forensics(6200 + case, started);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(started);
|
||||
|
||||
let mut stream = Vec::new();
|
||||
let mut local_seed = seed ^ case;
|
||||
for _ in 0..220 {
|
||||
local_seed ^= local_seed << 7;
|
||||
local_seed ^= local_seed >> 9;
|
||||
local_seed ^= local_seed << 8;
|
||||
if (local_seed & 1) == 0 {
|
||||
append_tiny_frame(&mut stream, proto);
|
||||
} else {
|
||||
let b = (local_seed >> 8) as u8;
|
||||
append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
|
||||
}
|
||||
}
|
||||
|
||||
let encrypted = encrypt_for_reader(&stream);
|
||||
write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
|
||||
drop(writer);
|
||||
|
||||
for _ in 0..260 {
|
||||
let step = run_relay_test_step_timeout(
|
||||
"fuzz proto read step",
|
||||
read_once_with_state(
|
||||
&mut crypto_reader,
|
||||
proto,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&mut idle_state,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
match step {
|
||||
Ok(Some((_payload, _))) => {}
|
||||
Err(ProxyError::Proxy(_)) => break,
|
||||
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,798 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
|
||||
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
|
||||
where
|
||||
T: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||
}
|
||||
|
||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.encrypt(plaintext)
|
||||
}
|
||||
|
||||
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
|
||||
RelayForensicsState {
|
||||
trace_id: 0xB100_0000 + conn_id,
|
||||
conn_id,
|
||||
user: format!("tiny-frame-debt-user-{conn_id}"),
|
||||
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
|
||||
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
|
||||
started_at,
|
||||
bytes_c2me: 0,
|
||||
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||
desync_all_full: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
|
||||
RelayClientIdlePolicy {
|
||||
enabled: true,
|
||||
soft_idle: Duration::from_millis(50),
|
||||
hard_idle: Duration::from_millis(120),
|
||||
grace_after_downstream_activity: Duration::from_secs(0),
|
||||
legacy_frame_read_timeout: Duration::from_millis(50),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_bounded(
|
||||
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
|
||||
proto_tag: ProtoTag,
|
||||
buffer_pool: &Arc<BufferPool>,
|
||||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
stats: &Stats,
|
||||
idle_policy: &RelayClientIdlePolicy,
|
||||
idle_state: &mut RelayClientIdleState,
|
||||
last_downstream_activity_ms: &AtomicU64,
|
||||
session_started_at: Instant,
|
||||
) -> Result<Option<(PooledBuffer, bool)>> {
|
||||
run_relay_test_step_timeout(
|
||||
"tiny-frame debt read step",
|
||||
read_client_payload_with_idle_policy(
|
||||
crypto_reader,
|
||||
proto_tag,
|
||||
1024,
|
||||
buffer_pool,
|
||||
forensics,
|
||||
frame_counter,
|
||||
stats,
|
||||
idle_policy,
|
||||
idle_state,
|
||||
last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
|
||||
let mut debt = 0u32;
|
||||
let mut reals = 0usize;
|
||||
for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
|
||||
if is_tiny {
|
||||
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
|
||||
if debt >= TINY_FRAME_DEBT_LIMIT {
|
||||
return (Some(idx + 1), debt, reals);
|
||||
}
|
||||
} else {
|
||||
reals = reals.saturating_add(1);
|
||||
debt = debt.saturating_sub(1);
|
||||
}
|
||||
}
|
||||
(None, debt, reals)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tiny_frame_debt_constants_match_security_budget_expectations() {
|
||||
assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
|
||||
assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relay_client_idle_state_initial_debt_is_zero() {
|
||||
let state = RelayClientIdleState::new(Instant::now());
|
||||
assert_eq!(state.tiny_frame_debt, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn on_client_frame_does_not_reset_tiny_frame_debt() {
|
||||
let now = Instant::now();
|
||||
let mut state = RelayClientIdleState::new(now);
|
||||
state.tiny_frame_debt = 77;
|
||||
state.on_client_frame(now);
|
||||
assert_eq!(state.tiny_frame_debt, 77);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tiny_frame_debt_increment_is_saturating() {
|
||||
let mut debt = u32::MAX - 1;
|
||||
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
|
||||
assert_eq!(debt, u32::MAX);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tiny_frame_debt_decrement_is_saturating() {
|
||||
let mut debt = 0u32;
|
||||
debt = debt.saturating_sub(1);
|
||||
assert_eq!(debt, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consecutive_tiny_frames_close_exactly_at_threshold() {
|
||||
let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
|
||||
let pattern = vec![true; max_tiny_without_close];
|
||||
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert_eq!(closed_at, Some(max_tiny_without_close));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_less_than_threshold_tiny_frames_do_not_close() {
|
||||
let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
|
||||
let pattern = vec![true; tiny_count];
|
||||
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert_eq!(closed_at, None);
|
||||
assert!(debt < TINY_FRAME_DEBT_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
|
||||
let mut pattern = Vec::with_capacity(512);
|
||||
for _ in 0..256 {
|
||||
pattern.push(true);
|
||||
pattern.push(false);
|
||||
}
|
||||
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert!(closed_at.is_some());
|
||||
assert!(reals <= 80, "expected bounded real frames before close, got {reals}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alternating_one_to_eight_is_stable_for_long_runs() {
|
||||
let mut pattern = Vec::with_capacity(9 * 5000);
|
||||
for _ in 0..5000 {
|
||||
pattern.push(true);
|
||||
for _ in 0..8 {
|
||||
pattern.push(false);
|
||||
}
|
||||
}
|
||||
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert_eq!(closed_at, None);
|
||||
assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn alternating_one_to_seven_eventually_closes() {
|
||||
let mut pattern = Vec::with_capacity(8 * 2000);
|
||||
for _ in 0..2000 {
|
||||
pattern.push(true);
|
||||
for _ in 0..7 {
|
||||
pattern.push(false);
|
||||
}
|
||||
}
|
||||
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_tiny_one_real_closes_faster_than_one_to_one() {
|
||||
let mut one_to_one = Vec::with_capacity(512);
|
||||
for _ in 0..256 {
|
||||
one_to_one.push(true);
|
||||
one_to_one.push(false);
|
||||
}
|
||||
|
||||
let mut two_to_one = Vec::with_capacity(768);
|
||||
for _ in 0..256 {
|
||||
two_to_one.push(true);
|
||||
two_to_one.push(true);
|
||||
two_to_one.push(false);
|
||||
}
|
||||
|
||||
let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len());
|
||||
let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len());
|
||||
assert!(a_close.is_some() && b_close.is_some());
|
||||
assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn burst_then_drain_can_recover_without_close() {
|
||||
let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize;
|
||||
let mut pattern = Vec::with_capacity(burst_tiny + 600);
|
||||
for _ in 0..burst_tiny {
|
||||
pattern.push(true);
|
||||
}
|
||||
pattern.extend(std::iter::repeat_n(false, 600));
|
||||
|
||||
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert_eq!(closed_at, None);
|
||||
assert_eq!(debt, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() {
|
||||
let mut seed = 0xA5A5_91C3_2026_0322u64;
|
||||
for _case in 0..128 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let len = 512 + ((seed as usize) & 0x3ff);
|
||||
let mut pattern = Vec::with_capacity(len);
|
||||
let mut local_seed = seed;
|
||||
for _ in 0..len {
|
||||
local_seed ^= local_seed << 7;
|
||||
local_seed ^= local_seed >> 9;
|
||||
local_seed ^= local_seed << 8;
|
||||
pattern.push((local_seed & 1) == 0);
|
||||
}
|
||||
|
||||
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
if closed_at.is_none() {
|
||||
assert!(debt < TINY_FRAME_DEBT_LIMIT);
|
||||
}
|
||||
assert!(debt <= u32::MAX);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_many_independent_simulations_keep_isolated_debt_state() {
|
||||
for idx in 0..2048usize {
|
||||
let mut pattern = Vec::with_capacity(64);
|
||||
for j in 0..64usize {
|
||||
pattern.push(((idx ^ j) & 3) == 0);
|
||||
}
|
||||
let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(11, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let flood_plaintext = vec![0u8; 4 * 256];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer.write_all(&flood_encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::Proxy(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(12, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let flood_plaintext = vec![0u8; 4 * 256];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer.write_all(&flood_encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Secure,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::Proxy(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_alternating_zero_and_real_eventually_closes() {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(13, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(3000);
|
||||
for idx in 0..160u8 {
|
||||
plaintext.extend_from_slice(&0u32.to_le_bytes());
|
||||
plaintext.extend_from_slice(&4u32.to_le_bytes());
|
||||
plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer.write_all(&encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..220 {
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Intermediate,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(_)) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(other) => panic!("unexpected error while probing alternating close: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(closed, "intermediate alternating attack must fail closed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(14, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(64);
|
||||
for _ in 0..8 {
|
||||
plaintext.push(0x00);
|
||||
}
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[1, 2, 3, 4]);
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer.write_all(&encrypted).await.unwrap();
|
||||
|
||||
let first = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
match first {
|
||||
Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]),
|
||||
Err(e) => panic!("unexpected close after small tiny burst: {e}"),
|
||||
Ok(None) => panic!("unexpected EOF before real frame"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(1, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let flood_plaintext = vec![0u8; 1024];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer
|
||||
.write_all(&flood_encrypted)
|
||||
.await
|
||||
.expect("zero-length flood bytes must be writable");
|
||||
drop(writer);
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::Proxy(_))),
|
||||
"idle policy enabled must fail closed for pure zero-length flood"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {
|
||||
let (reader, mut writer) = duplex(8192);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(2, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let mut plaintext = Vec::with_capacity(256 * 6);
|
||||
for idx in 0..=255u8 {
|
||||
plaintext.push(0x00);
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]);
|
||||
}
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer
|
||||
.write_all(&encrypted)
|
||||
.await
|
||||
.expect("alternating flood bytes must be writable");
|
||||
drop(writer);
|
||||
|
||||
let mut saw_proxy_close = false;
|
||||
for _ in 0..300 {
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some((_payload, _quickack))) => {}
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
saw_proxy_close = true;
|
||||
break;
|
||||
}
|
||||
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"),
|
||||
Ok(None) => panic!("unexpected EOF before debt-based closure"),
|
||||
Err(other) => panic!("unexpected error before close: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_proxy_close,
|
||||
"alternating tiny/real sequence must eventually fail closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(3, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let payload = [7u8, 8, 9, 10];
|
||||
let mut plaintext = Vec::with_capacity(1 + payload.len());
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer
|
||||
.write_all(&encrypted)
|
||||
.await
|
||||
.expect("nonzero frame must be writable");
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await
|
||||
.expect("valid frame should decode")
|
||||
.expect("valid frame should return payload");
|
||||
|
||||
assert_eq!(result.0.as_ref(), &payload);
|
||||
assert!(!result.1);
|
||||
assert_eq!(frame_counter, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abridged_quickack_tiny_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(21, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let flood_plaintext = vec![0x80u8; 256];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer.write_all(&flood_encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::Proxy(_))),
|
||||
"quickack-marked zero-length flood must fail closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abridged_extended_zero_len_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(4096);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(22, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let mut flood_plaintext = Vec::with_capacity(4 * 256);
|
||||
for _ in 0..256 {
|
||||
flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]);
|
||||
}
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer.write_all(&flood_encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let result = read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::Proxy(_))),
|
||||
"extended zero-length abridged flood must fail closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() {
|
||||
let mut plaintext = Vec::with_capacity(9 * 300);
|
||||
for idx in 0..300usize {
|
||||
plaintext.push(0x00);
|
||||
for _ in 0..8 {
|
||||
let b = idx as u8;
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]);
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the test single-task and deterministic: make duplex capacity larger than the
|
||||
// generated ciphertext so write_all cannot block waiting for a concurrent reader.
|
||||
let duplex_capacity = plaintext.len().saturating_add(1024);
|
||||
let (reader, mut writer) = duplex(duplex_capacity);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(23, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer.write_all(&encrypted).await.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut closed = false;
|
||||
for _ in 0..3000 {
|
||||
match read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => break,
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
closed = true;
|
||||
break;
|
||||
}
|
||||
Err(other) => panic!("unexpected error in 1:8 wire test: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
!closed,
|
||||
"wire-level 1:8 tiny-to-real pattern should not trigger debt close"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() {
|
||||
let mut seed = 0xD1CE_BAAD_2026_0322u64;
|
||||
|
||||
for case_idx in 0..32u64 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let events = 300 + ((seed as usize) & 0xff);
|
||||
let mut pattern = Vec::with_capacity(events);
|
||||
let mut local = seed;
|
||||
for _ in 0..events {
|
||||
local ^= local << 7;
|
||||
local ^= local >> 9;
|
||||
local ^= local << 8;
|
||||
pattern.push((local & 0x03) == 0);
|
||||
}
|
||||
|
||||
let mut plaintext = Vec::with_capacity(events * 6);
|
||||
for (idx, tiny) in pattern.iter().copied().enumerate() {
|
||||
if tiny {
|
||||
plaintext.push(0x00);
|
||||
} else {
|
||||
let b = (idx as u8) ^ (case_idx as u8);
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]);
|
||||
}
|
||||
}
|
||||
|
||||
let (reader, mut writer) = duplex(16 * 1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(500 + case_idx, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
let mut idle_state = RelayClientIdleState::new(session_started_at);
|
||||
let idle_policy = make_enabled_idle_policy();
|
||||
let last_downstream_activity_ms = AtomicU64::new(0);
|
||||
|
||||
writer
|
||||
.write_all(&encrypt_for_reader(&plaintext))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
let mut observed_close = false;
|
||||
|
||||
for _ in 0..(events + 8) {
|
||||
match read_bounded(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
&idle_policy,
|
||||
&mut idle_state,
|
||||
&last_downstream_activity_ms,
|
||||
session_started_at,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => break,
|
||||
Err(ProxyError::Proxy(_)) => {
|
||||
observed_close = true;
|
||||
break;
|
||||
}
|
||||
Err(other) => panic!("unexpected fuzz error: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
observed_close,
|
||||
expected_close.is_some(),
|
||||
"wire parser behavior must match debt model for case {case_idx}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
use std::time::Instant;
|
||||
|
||||
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
|
||||
where
|
||||
T: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||
}
|
||||
|
||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut cipher = AesCtr::new(&key, iv);
|
||||
cipher.encrypt(plaintext)
|
||||
}
|
||||
|
||||
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
|
||||
RelayForensicsState {
|
||||
trace_id: 0xB000_0000 + conn_id,
|
||||
conn_id,
|
||||
user: format!("zero-len-test-user-{conn_id}"),
|
||||
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
|
||||
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
|
||||
started_at,
|
||||
bytes_c2me: 0,
|
||||
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||
desync_all_full: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_legacy_zero_length_flood_is_fail_closed() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(1, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
|
||||
let flood_plaintext = vec![0u8; 128];
|
||||
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
|
||||
writer
|
||||
.write_all(&flood_encrypted)
|
||||
.await
|
||||
.expect("zero-length flood bytes must be writable");
|
||||
drop(writer);
|
||||
|
||||
let result = read_client_payload_legacy(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
1024,
|
||||
Duration::from_millis(30),
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(ProxyError::Proxy(msg)) => {
|
||||
assert!(
|
||||
msg.contains("Excessive zero-length"),
|
||||
"legacy mode must close flood with explicit zero-length reason, got: {msg}"
|
||||
);
|
||||
}
|
||||
Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"),
|
||||
Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"),
|
||||
Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_abridged_nonzero_frame_still_passes() {
|
||||
let (reader, mut writer) = duplex(1024);
|
||||
let mut crypto_reader = make_crypto_reader(reader);
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let stats = Stats::new();
|
||||
let session_started_at = Instant::now();
|
||||
let forensics = make_forensics(2, session_started_at);
|
||||
let mut frame_counter = 0u64;
|
||||
|
||||
let payload = [1u8, 2, 3, 4];
|
||||
let mut plaintext = Vec::with_capacity(1 + payload.len());
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
writer
|
||||
.write_all(&encrypted)
|
||||
.await
|
||||
.expect("nonzero abridged frame must be writable");
|
||||
|
||||
let result = read_client_payload_legacy(
|
||||
&mut crypto_reader,
|
||||
ProtoTag::Abridged,
|
||||
1024,
|
||||
Duration::from_millis(30),
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
)
|
||||
.await
|
||||
.expect("valid abridged frame should decode")
|
||||
.expect("valid abridged frame should return payload");
|
||||
|
||||
assert_eq!(result.0.as_ref(), &payload);
|
||||
assert!(!result.1, "quickack flag must remain false");
|
||||
assert_eq!(frame_counter, 1);
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK
|
||||
.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn same_user_returns_same_lock_identity() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let a = cross_mode_quota_user_lock("cross-mode-same-user");
|
||||
let b = cross_mode_quota_user_lock("cross-mode-same-user");
|
||||
|
||||
assert!(
|
||||
Arc::ptr_eq(&a, &b),
|
||||
"same user must reuse a stable lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let prefix = format!("cross-mode-saturated-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"lock cache must be saturated for overflow check"
|
||||
);
|
||||
|
||||
let overflow_user = format!("cross-mode-overflow-{}", std::process::id());
|
||||
let overflow_a = cross_mode_quota_user_lock(&overflow_user);
|
||||
let overflow_b = cross_mode_quota_user_lock(&overflow_user);
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"overflow path must not grow bounded lock cache"
|
||||
);
|
||||
assert!(
|
||||
locks.get(&overflow_user).is_none(),
|
||||
"overflow user must stay on striped fallback while cache is saturated"
|
||||
);
|
||||
assert!(
|
||||
Arc::ptr_eq(&overflow_a, &overflow_b),
|
||||
"overflow user must receive a stable striped lock across repeated lookups"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let prefix = format!("cross-mode-reclaim-{}", std::process::id());
|
||||
let protected_user = format!("{prefix}-protected");
|
||||
|
||||
let protected_lock = cross_mode_quota_user_lock(&protected_user);
|
||||
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1));
|
||||
for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) {
|
||||
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"fixture must saturate lock cache before reclaim path is exercised"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
|
||||
let newcomer_user = format!("{prefix}-newcomer");
|
||||
let _newcomer = cross_mode_quota_user_lock(&newcomer_user);
|
||||
|
||||
assert!(
|
||||
locks.get(&protected_user).is_some(),
|
||||
"active protected user must remain cache-resident after reclaim"
|
||||
);
|
||||
let locked = locks
|
||||
.get(&protected_user)
|
||||
.expect("protected user must remain in map after reclaim");
|
||||
assert!(
|
||||
Arc::ptr_eq(locked.value(), &protected_lock),
|
||||
"reclaim must not swap active user lock identity"
|
||||
);
|
||||
assert!(
|
||||
locks.get(&newcomer_user).is_some(),
|
||||
"newcomer should become cacheable after stale entries are reclaimed"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("relay-pipeline-stall-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xA1])
|
||||
.await
|
||||
.expect("server write should enqueue while relay is stalled");
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await;
|
||||
assert!(
|
||||
blocked_read.is_err(),
|
||||
"same-user relay must remain blocked while cross-mode lock is held"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("blocked relay must resume after cross-mode lock release")
|
||||
.expect("resumed relay must deliver queued byte");
|
||||
assert_eq!(one, [0xA1]);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("relay task must complete")
|
||||
.expect("relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id());
|
||||
let free_user = format!("relay-pipeline-free-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold blocked user's shared cross-mode lock");
|
||||
|
||||
let stats_blocked = Arc::new(Stats::new());
|
||||
let stats_free = Arc::new(Stats::new());
|
||||
|
||||
let (mut blocked_client, blocked_relay_client) = duplex(1024);
|
||||
let (blocked_relay_server, mut blocked_server) = duplex(1024);
|
||||
let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client);
|
||||
let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server);
|
||||
|
||||
let (mut free_client, free_relay_client) = duplex(1024);
|
||||
let (free_relay_server, mut free_server) = duplex(1024);
|
||||
let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client);
|
||||
let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server);
|
||||
|
||||
let blocked_task = {
|
||||
let user = blocked_user.clone();
|
||||
let stats = Arc::clone(&stats_blocked);
|
||||
tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
blocked_client_reader,
|
||||
blocked_client_writer,
|
||||
blocked_server_reader,
|
||||
blocked_server_writer,
|
||||
256,
|
||||
256,
|
||||
&user,
|
||||
stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
let free_task = {
|
||||
let user = free_user.clone();
|
||||
let stats = Arc::clone(&stats_free);
|
||||
tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
free_client_reader,
|
||||
free_client_writer,
|
||||
free_server_reader,
|
||||
free_server_writer,
|
||||
256,
|
||||
256,
|
||||
&user,
|
||||
stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
blocked_server
|
||||
.write_all(&[0xB1])
|
||||
.await
|
||||
.expect("blocked user server write should queue");
|
||||
free_server
|
||||
.write_all(&[0xC1])
|
||||
.await
|
||||
.expect("free user server write should queue");
|
||||
|
||||
let mut blocked_buf = [0u8; 1];
|
||||
let mut free_buf = [0u8; 1];
|
||||
|
||||
let blocked_stalled = timeout(
|
||||
Duration::from_millis(40),
|
||||
blocked_client.read_exact(&mut blocked_buf),
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
blocked_stalled.is_err(),
|
||||
"blocked user must remain stalled while its lock is held"
|
||||
);
|
||||
|
||||
timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf))
|
||||
.await
|
||||
.expect("free user must make progress while other user is blocked")
|
||||
.expect("free user read must succeed");
|
||||
assert_eq!(free_buf, [0xC1]);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf))
|
||||
.await
|
||||
.expect("blocked user must resume after release")
|
||||
.expect("blocked user resumed read must succeed");
|
||||
assert_eq!(blocked_buf, [0xB1]);
|
||||
|
||||
drop(blocked_client);
|
||||
drop(blocked_server);
|
||||
drop(free_client);
|
||||
drop(free_server);
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), blocked_task)
|
||||
.await
|
||||
.expect("blocked relay task must complete")
|
||||
.expect("blocked relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), free_task)
|
||||
.await
|
||||
.expect("free relay task must complete")
|
||||
.expect("free relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0x5EED_C0DE_2026_0323u64;
|
||||
for round in 0..24u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = 2 + (seed % 10);
|
||||
let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock during fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xD1])
|
||||
.await
|
||||
.expect("server write should queue in fuzz round");
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await;
|
||||
assert!(stalled.is_err(), "held phase must stall same-user relay");
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("released phase must resume same-user relay")
|
||||
.expect("released phase read must succeed");
|
||||
assert_eq!(one, [0xD1]);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("fuzz relay task must complete")
|
||||
.expect("fuzz relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,213 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::sync::{Barrier, watch};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn percentile_index(len: usize, percentile: usize) -> usize {
|
||||
((len * percentile) / 100).min(len.saturating_sub(1))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let rounds = 64usize;
|
||||
let user = format!("relay-pipeline-latency-single-{}", std::process::id());
|
||||
let mut samples_ms = Vec::with_capacity(rounds);
|
||||
|
||||
for round in 0..rounds {
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(2048),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[(round as u8) ^ 0xA5])
|
||||
.await
|
||||
.expect("server write should queue before release");
|
||||
|
||||
let release_at = Instant::now();
|
||||
drop(held_guard);
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
timeout(Duration::from_millis(450), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("client must receive queued byte after release")
|
||||
.expect("queued byte read must succeed");
|
||||
samples_ms.push(release_at.elapsed().as_millis() as u64);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("relay task must complete")
|
||||
.expect("relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
}
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
|
||||
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
|
||||
|
||||
assert!(
|
||||
p50_ms <= 45,
|
||||
"single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
assert!(
|
||||
p95_ms <= 130,
|
||||
"single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let waiters = 128usize;
|
||||
let user = format!("relay-pipeline-latency-fanout-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared lock before fanout release benchmark");
|
||||
|
||||
let ready_barrier = Arc::new(Barrier::new(waiters + 1));
|
||||
let release_at = Arc::new(Mutex::new(None::<Instant>));
|
||||
let (release_tx, release_rx) = watch::channel(false);
|
||||
let mut tasks = Vec::with_capacity(waiters);
|
||||
|
||||
for idx in 0..waiters {
|
||||
let user = user.clone();
|
||||
let barrier = Arc::clone(&ready_barrier);
|
||||
let release_at = Arc::clone(&release_at);
|
||||
let mut release_rx = release_rx.clone();
|
||||
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(512);
|
||||
let (relay_server, mut server_peer) = duplex(512);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user;
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(2048),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[(idx as u8) ^ 0x5A])
|
||||
.await
|
||||
.expect("fanout server write should queue before release");
|
||||
|
||||
barrier.wait().await;
|
||||
release_rx
|
||||
.changed()
|
||||
.await
|
||||
.expect("release signal should remain available");
|
||||
|
||||
let started = {
|
||||
let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
guard.expect("release timestamp must be populated before signal")
|
||||
};
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
timeout(Duration::from_millis(900), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("fanout waiter must receive queued byte after release")
|
||||
.expect("fanout waiter read must succeed");
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("fanout relay task must complete")
|
||||
.expect("fanout relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
|
||||
started.elapsed().as_millis() as u64
|
||||
}));
|
||||
}
|
||||
|
||||
ready_barrier.wait().await;
|
||||
{
|
||||
let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
*guard = Some(Instant::now());
|
||||
}
|
||||
drop(held_guard);
|
||||
release_tx
|
||||
.send(true)
|
||||
.expect("release broadcast must succeed");
|
||||
|
||||
let mut samples_ms = Vec::with_capacity(waiters);
|
||||
timeout(Duration::from_secs(8), async {
|
||||
for task in tasks {
|
||||
let elapsed = task.await.expect("fanout waiter must not panic");
|
||||
samples_ms.push(elapsed);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("fanout benchmark must complete in bounded time");
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
|
||||
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
|
||||
let max_ms = *samples_ms.last().unwrap_or(&0);
|
||||
|
||||
assert!(
|
||||
p50_ms <= 120,
|
||||
"fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
assert!(
|
||||
p95_ms <= 260,
|
||||
"fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
assert!(
|
||||
max_ms <= 700,
|
||||
"fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,604 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_cross_mode_uncontended_writer_progresses() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
"cross-mode-tdd-uncontended".to_string(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let result = io.write_all(&[0x11, 0x22]).await;
|
||||
assert!(result.is_ok(), "uncontended writer must progress");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-held-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(poll.is_pending(), "writer must not bypass held cross-mode lock");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_parallel_waiters_resume_after_cross_mode_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-resume-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before launching waiters");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..16 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x7F]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let result = waiter.await.expect("waiter task must not panic");
|
||||
assert!(result.is_ok(), "waiter must complete after cross-mode release");
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-wakes-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut ios = Vec::new();
|
||||
let mut counters = Vec::new();
|
||||
for _ in 0..20 {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0x33]);
|
||||
assert!(poll.is_pending());
|
||||
counters.push(wake_counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
let total_wakes: usize = counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= 20 * 4,
|
||||
"cross-mode contention should not create wake storms; wakes={total_wakes}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0xC0DE_BAAD_2026_0322u64;
|
||||
for round in 0..16u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let sleep_ms = 2 + (seed as u64 % 8);
|
||||
let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock in fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user_reader = user.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user_reader,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
});
|
||||
|
||||
let user_writer = user.clone();
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user_writer,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x44]).await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let read_done = timeout(Duration::from_millis(350), reader_task)
|
||||
.await
|
||||
.expect("reader task must complete after release")
|
||||
.expect("reader task must not panic");
|
||||
assert!(read_done.is_ok());
|
||||
|
||||
let write_done = timeout(Duration::from_millis(350), writer_task)
|
||||
.await
|
||||
.expect("writer task must complete after release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(write_done.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_middle_lock_blocks_relay_reader_for_same_user() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-middle-reader-block-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay shared lock");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let mut one = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut one);
|
||||
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(poll.is_pending());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_middle_lock_release_unblocks_relay_reader() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-middle-reader-release-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay shared lock");
|
||||
|
||||
let task = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(300), task)
|
||||
.await
|
||||
.expect("reader task must complete after release")
|
||||
.expect("reader task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_different_user_middle_lock_does_not_block_relay_writer() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let held_user = format!("cross-mode-middle-held-{}", std::process::id());
|
||||
let active_user = format!("cross-mode-middle-active-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay lock for other user");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
active_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]);
|
||||
assert!(matches!(poll, Poll::Ready(Ok(1))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-none-limit-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock while quota is disabled");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
None,
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]);
|
||||
assert!(matches!(poll, Poll::Ready(Ok(2))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-pre-exceeded-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared lock before poll");
|
||||
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(true));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::clone("a_exceeded),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]);
|
||||
assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-repoll-held-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock for repoll sequence");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
for _ in 0..8 {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]);
|
||||
assert!(poll.is_pending());
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_same_user_mixed_read_write_waiters_resume_after_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-mixed-resume-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before spawning mixed waiters");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for i in 0..12usize {
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
if i % 2 == 0 {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut b = [0u8; 1];
|
||||
io.read(&mut b).await.map(|_| ())
|
||||
} else {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x66]).await
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(8)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("mixed waiter task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all mixed waiters must finish after release");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let blocked_user = format!("cross-mode-blocked-{}", std::process::id());
|
||||
let free_user = format!("cross-mode-free-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold blocked user lock");
|
||||
|
||||
let blocked_task = tokio::spawn({
|
||||
let blocked_user = blocked_user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
blocked_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x77]).await
|
||||
}
|
||||
});
|
||||
|
||||
let free_task = tokio::spawn({
|
||||
let free_user = free_user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
free_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x78]).await
|
||||
}
|
||||
});
|
||||
|
||||
let free_done = timeout(Duration::from_millis(250), free_task)
|
||||
.await
|
||||
.expect("free user must not be blocked")
|
||||
.expect("free user task must not panic");
|
||||
assert!(free_done.is_ok());
|
||||
|
||||
drop(held_guard);
|
||||
let blocked_done = timeout(Duration::from_secs(1), blocked_task)
|
||||
.await
|
||||
.expect("blocked user must resume after release")
|
||||
.expect("blocked user task must not panic");
|
||||
assert!(blocked_done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-fanout-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before fanout");
|
||||
|
||||
let waiters = 48usize;
|
||||
let gate = Arc::new(Barrier::new(waiters + 1));
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..waiters {
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x79]).await
|
||||
}));
|
||||
}
|
||||
|
||||
gate.wait().await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("fanout task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("fanout waiters must complete after release");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0xA11C_EE55_2026_0323u64;
|
||||
for round in 0..20u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = 2 + (seed % 10);
|
||||
let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock in fuzz round");
|
||||
|
||||
let writer = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x7A]).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(400), writer)
|
||||
.await
|
||||
.expect("writer must complete after lock release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::Waker;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
|
||||
let user = "cross-mode-lock-shared-user";
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before relay poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(crate::stats::Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"relay writer must not bypass cross-mode lock held by middle-relay path"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
|
||||
let user = "cross-mode-lock-progress-user";
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(crate::stats::Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Ready(Ok(2))),
|
||||
"relay writer should progress when shared cross-mode lock is uncontended"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
format!("dual-lock-alt-positive-{}", std::process::id()),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = io.write_all(&[0xAA, 0xBB]).await;
|
||||
assert!(write.is_ok(), "uncontended write must complete");
|
||||
assert_eq!(
|
||||
io.quota_write_retry_attempt, 0,
|
||||
"uncontended write must not advance retry backoff"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-adversarial-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("test must hold local quota lock initially"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(first.is_pending(), "held local lock must block first poll");
|
||||
|
||||
let mut observed_wakes = 0usize;
|
||||
for idx in 0..18usize {
|
||||
tokio::time::sleep(Duration::from_millis(6)).await;
|
||||
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = Some(
|
||||
cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("cross-mode lock should be acquirable while local lock released"),
|
||||
);
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("local lock should be acquirable while cross lock released"),
|
||||
);
|
||||
}
|
||||
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed_wakes {
|
||||
observed_wakes = wakes;
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
|
||||
assert!(
|
||||
pending.is_pending(),
|
||||
"alternating contention must keep write pending while one lock is held"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_write_retry_attempt >= 2,
|
||||
"alternating contention must still ramp retry backoff; got {}",
|
||||
io.quota_write_retry_attempt
|
||||
);
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 32,
|
||||
"alternating contention must stay wake-rate-limited"
|
||||
);
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]);
|
||||
assert!(ready.is_ready(), "writer must resume after both locks released");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_retry_scheduler_resets_after_alternating_contention_clears() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-edge-reset-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let local_guard = local_lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock for edge scenario");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]);
|
||||
assert!(first.is_pending());
|
||||
tokio::time::sleep(Duration::from_millis(15)).await;
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]);
|
||||
assert!(ready.is_ready());
|
||||
assert_eq!(
|
||||
io.quota_write_retry_attempt, 0,
|
||||
"successful dual-lock acquisition must reset retry scheduler"
|
||||
);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-integration-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..16usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_secs(2), io.write_all(&[0x31])).await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("integration toggle must acquire local lock first"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
|
||||
for idx in 0..24usize {
|
||||
tokio::time::sleep(Duration::from_millis(4)).await;
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = cross_mode_lock.try_lock().ok();
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = local_lock.try_lock().ok();
|
||||
}
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("waiter task must not panic");
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"waiter must finish once alternating contention window ends"
|
||||
);
|
||||
assert!(done.expect("waiter timeout must not fire").is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-fuzz-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let mut seed = 0xD00D_BAAD_F00D_2026u64;
|
||||
|
||||
for _round in 0..64u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_mode = (seed % 3) as u8;
|
||||
let local_guard = if hold_mode == 0 {
|
||||
Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("fuzz local lock should be acquirable"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let cross_guard = if hold_mode == 1 {
|
||||
Some(
|
||||
cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("fuzz cross lock should be acquirable"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await;
|
||||
if hold_mode == 2 {
|
||||
assert!(write.is_ok(), "unheld fuzz round must make progress");
|
||||
assert!(write.expect("unheld round timeout").is_ok());
|
||||
} else {
|
||||
assert!(
|
||||
write.is_err(),
|
||||
"held-lock fuzz round must remain pending inside bounded window"
|
||||
);
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_fanout_alternating_contention_recovers_without_hanging() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-stress-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..48usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("stress toggle must acquire local lock first"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
for idx in 0..40usize {
|
||||
tokio::time::sleep(Duration::from_millis(3)).await;
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = cross_mode_lock.try_lock().ok();
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = local_lock.try_lock().ok();
|
||||
}
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("stress waiter task must not panic");
|
||||
assert!(done.is_ok(), "stress waiter timed out under alternating contention");
|
||||
assert!(done.expect("stress waiter timeout should not fire").is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-backoff-{}", std::process::id());
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_cross_mode_guard = cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(first.is_pending(), "held cross-mode lock must block writer");
|
||||
|
||||
let started = Instant::now();
|
||||
let mut last_wakes = 0usize;
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > last_wakes {
|
||||
last_wakes = wakes;
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
|
||||
assert!(next.is_pending(), "writer must remain blocked while lock is held");
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_write_retry_attempt >= 2,
|
||||
"retry attempt must ramp under sustained second-lock contention; got {}",
|
||||
io.quota_write_retry_attempt
|
||||
);
|
||||
|
||||
drop(held_cross_mode_guard);
|
||||
}
|
||||
|
|
@ -0,0 +1,325 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
format!("dual-lock-positive-{}", std::process::id()),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]);
|
||||
assert!(poll.is_ready());
|
||||
assert_eq!(io.quota_write_retry_attempt, 0);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_local_lock_contention_read_retry_attempt_ramps() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-local-contention-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold local quota lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
let mut one = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut one);
|
||||
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let started = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let mut step_buf = ReadBuf::new(&mut one);
|
||||
let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_read_retry_attempt >= 2,
|
||||
"retry attempt must ramp under sustained local-lock contention; got {}",
|
||||
io.quota_read_retry_attempt
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-reset-{}", std::process::id());
|
||||
let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = cross_mode
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
|
||||
assert!(ready.is_ready());
|
||||
assert_eq!(io.quota_write_retry_attempt, 0);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-adversarial-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before launching waiters");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..24usize {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_millis(40), io.write_all(&[0x33])).await
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let timed = task.await.expect("waiter task must not panic");
|
||||
assert!(timed.is_err(), "held cross-mode lock must keep waiter pending");
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_waiters_resume_after_cross_mode_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-integration-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before starting waiter");
|
||||
|
||||
let task = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x44]).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_secs(1), task)
|
||||
.await
|
||||
.expect("waiter task must complete after release")
|
||||
.expect("waiter task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-fuzz-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut seed = 0xA55A_55AA_C3D2_E1F0u64;
|
||||
|
||||
for _round in 0..48u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_mode = (seed % 3) as u8;
|
||||
let mut local_lock = None;
|
||||
let mut cross_lock = None;
|
||||
let mut local_guard = None;
|
||||
let mut cross_guard = None;
|
||||
|
||||
if hold_mode == 0 {
|
||||
local_lock = Some(quota_user_lock(&user));
|
||||
local_guard = Some(
|
||||
local_lock
|
||||
.as_ref()
|
||||
.expect("local lock should be present")
|
||||
.try_lock()
|
||||
.expect("local lock should be acquirable in fuzz round"),
|
||||
);
|
||||
} else if hold_mode == 1 {
|
||||
cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
|
||||
&user,
|
||||
));
|
||||
cross_guard = Some(
|
||||
cross_lock
|
||||
.as_ref()
|
||||
.expect("cross lock should be present")
|
||||
.try_lock()
|
||||
.expect("cross lock should be acquirable in fuzz round"),
|
||||
);
|
||||
}
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await;
|
||||
if hold_mode == 2 {
|
||||
assert!(write.is_ok(), "unheld round must make progress");
|
||||
} else {
|
||||
assert!(write.is_err(), "held-lock round must stay blocked within timeout");
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
drop(local_lock);
|
||||
drop(cross_lock);
|
||||
}
|
||||
|
||||
assert!(stats.get_user_total_octets(&user) <= 4096);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_fanout_waiters_complete_after_release_without_panics() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-stress-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before stress fanout");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..waiters {
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(12)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("stress waiter task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all stress waiters must complete after release");
|
||||
}
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn make_stats_io(user: String) -> StatsIo<tokio::io::Sink> {
|
||||
StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-race-fuzz-{}", std::process::id());
|
||||
let mut seed = 0xD1CE_BAAD_5EED_1234u64;
|
||||
|
||||
for round in 0..1024u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold = (seed & 1) == 0;
|
||||
let hold_ms = (seed % 3) as u64;
|
||||
|
||||
let maybe_lock = if hold {
|
||||
Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
|
||||
&user,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let maybe_guard = maybe_lock.as_ref().map(|lock| {
|
||||
lock.try_lock()
|
||||
.expect("cross-mode lock must be acquirable in fuzz round")
|
||||
});
|
||||
|
||||
if hold {
|
||||
let mut blocked_io = make_stats_io(user.clone());
|
||||
let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await;
|
||||
assert!(
|
||||
blocked.is_err(),
|
||||
"held round must block waiter before lock release (round={round})"
|
||||
);
|
||||
|
||||
if hold_ms > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
}
|
||||
} else {
|
||||
let mut free_io = make_stats_io(user.clone());
|
||||
let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await;
|
||||
assert!(
|
||||
free.is_ok(),
|
||||
"unheld round must complete promptly (round={round})"
|
||||
);
|
||||
assert!(free.expect("unheld round should complete").is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(350), async {
|
||||
let user = user.clone();
|
||||
let mut io = make_stats_io(user);
|
||||
io.write_all(&[0xA6]).await
|
||||
})
|
||||
.await
|
||||
.expect("post-release write must complete in bounded time");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-race-stress-{}", std::process::id());
|
||||
let mut seed = 0xC0FF_EE77_4444_9999u64;
|
||||
|
||||
for round in 0..256u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = (seed % 4) as u64;
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let guard = lock
|
||||
.try_lock()
|
||||
.expect("cross-mode lock must be acquirable at round start");
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..3usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = make_stats_io(user);
|
||||
io.write_all(&[0x55]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("waiter task must not panic");
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"waiter must complete after release (round={round})"
|
||||
);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time after release");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{RngExt, SeedableRng};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
async fn read_available<R: tokio::io::AsyncRead + Unpin>(reader: &mut R, budget: Duration) -> usize {
|
||||
let start = tokio::time::Instant::now();
|
||||
let mut total = 0usize;
|
||||
let mut buf = [0u8; 128];
|
||||
|
||||
loop {
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= budget {
|
||||
break;
|
||||
}
|
||||
let remaining = budget.saturating_sub(elapsed);
|
||||
match timeout(remaining, reader.read(&mut buf)).await {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => total = total.saturating_add(n),
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_quota_path_forwards_both_directions_within_limit() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-positive-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(16),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap();
|
||||
server_peer.read_exact(&mut [0u8; 4]).await.unwrap();
|
||||
|
||||
server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap();
|
||||
client_peer.read_exact(&mut [0u8; 4]).await.unwrap();
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(relay_result.is_ok());
|
||||
assert!(stats.get_user_total_octets(user) <= 16);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_preloaded_quota_forbids_any_forwarding() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-negative-user";
|
||||
stats.add_user_octets_from(user, 8);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
128,
|
||||
128,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(8),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer.write_all(&[0xAA]).await.unwrap();
|
||||
server_peer.write_all(&[0xBB]).await.unwrap();
|
||||
|
||||
assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0);
|
||||
assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert!(stats.get_user_total_octets(user) <= 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-edge-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
128,
|
||||
128,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(1),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let _ = tokio::join!(
|
||||
client_peer.write_all(&[0xFE]),
|
||||
server_peer.write_all(&[0xEF]),
|
||||
);
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
|
||||
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
|
||||
|
||||
assert!(delivered_s2c + delivered_c2s <= 1);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-blackhat-user";
|
||||
let quota = 24u64;
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(quota),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let mut total_forwarded = 0usize;
|
||||
|
||||
for i in 0..256usize {
|
||||
if relay.is_finished() {
|
||||
break;
|
||||
}
|
||||
if (i & 1) == 0 {
|
||||
let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
|
||||
total_forwarded += n;
|
||||
}
|
||||
} else {
|
||||
let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
|
||||
total_forwarded += n;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
|
||||
}
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert!(total_forwarded <= quota as usize);
|
||||
assert!(stats.get_user_total_octets(user) <= quota);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
|
||||
let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);
|
||||
|
||||
for case in 0..32u64 {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-extended-fuzz-{case}");
|
||||
let quota = rng.random_range(1u64..=35u64);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
Arc::clone(&relay_stats),
|
||||
Some(quota),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let mut total_forwarded = 0usize;
|
||||
|
||||
for _ in 0..96usize {
|
||||
if relay.is_finished() {
|
||||
break;
|
||||
}
|
||||
|
||||
if rng.random::<bool>() {
|
||||
let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await {
|
||||
total_forwarded += n;
|
||||
}
|
||||
} else {
|
||||
let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await {
|
||||
total_forwarded += n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert!(total_forwarded <= quota as usize);
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_relays_for_one_user_obey_global_quota() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-stress-user".to_string();
|
||||
let quota = 64u64;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for worker in 0..4u8 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
128,
|
||||
128,
|
||||
&relay_user,
|
||||
Arc::clone(&relay_stats),
|
||||
Some(quota),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let mut total = 0usize;
|
||||
for step in 0..64u8 {
|
||||
if relay.is_finished() {
|
||||
break;
|
||||
}
|
||||
if (step as usize + worker as usize) % 2 == 0 {
|
||||
let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
|
||||
total += n;
|
||||
}
|
||||
} else {
|
||||
let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
|
||||
total += n;
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
total
|
||||
}));
|
||||
}
|
||||
|
||||
let mut delivered = 0usize;
|
||||
for task in tasks {
|
||||
delivered += task.await.unwrap();
|
||||
}
|
||||
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(delivered <= quota as usize);
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
use super::*;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[test]
|
||||
fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-held-{}", std::process::id());
|
||||
let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id());
|
||||
let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&held_user);
|
||||
let stale_a = quota_user_lock(&stale_a_user);
|
||||
let stale_b = quota_user_lock(&stale_b_user);
|
||||
|
||||
assert!(map.get(&held_user).is_some());
|
||||
assert!(map.get(&stale_a_user).is_some());
|
||||
assert!(map.get(&stale_b_user).is_some());
|
||||
|
||||
drop(stale_a);
|
||||
drop(stale_b);
|
||||
|
||||
quota_user_lock_evict();
|
||||
|
||||
assert!(
|
||||
map.get(&held_user).is_some(),
|
||||
"held entry must survive eviction"
|
||||
);
|
||||
assert!(
|
||||
map.get(&stale_a_user).is_none(),
|
||||
"unheld stale entry must be reclaimed"
|
||||
);
|
||||
assert!(
|
||||
map.get(&stale_b_user).is_none(),
|
||||
"unheld stale entry must be reclaimed"
|
||||
);
|
||||
|
||||
drop(held);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-loop-held-{}", std::process::id());
|
||||
let stale_user = format!("quota-evict-loop-stale-{}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&held_user);
|
||||
let stale = quota_user_lock(&stale_user);
|
||||
|
||||
assert_eq!(map.len(), 2);
|
||||
drop(stale);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
|
||||
|
||||
timeout(Duration::from_millis(200), async {
|
||||
loop {
|
||||
if map.get(&stale_user).is_none() {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("periodic quota lock evictor must reclaim stale entry");
|
||||
|
||||
evictor.abort();
|
||||
|
||||
assert!(map.get(&held_user).is_some());
|
||||
assert!(map.get(&stale_user).is_none());
|
||||
|
||||
drop(held);
|
||||
}
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
use super::*;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
for worker in 0..24u32 {
|
||||
tasks.spawn(async move {
|
||||
for round in 0..320u32 {
|
||||
let user = format!(
|
||||
"quota-evict-stress-user-{}-{}-{}",
|
||||
std::process::id(),
|
||||
worker,
|
||||
round
|
||||
);
|
||||
let lock = quota_user_lock(&user);
|
||||
if round % 19 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
drop(lock);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(done) = tasks.join_next().await {
|
||||
done.expect("stress worker must not panic");
|
||||
}
|
||||
|
||||
quota_user_lock_evict();
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
|
||||
assert!(
|
||||
map.len() <= QUOTA_USER_LOCKS_MAX,
|
||||
"quota lock map must remain bounded after churn + eviction"
|
||||
);
|
||||
|
||||
let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id());
|
||||
let sanity_lock = quota_user_lock(&sanity_user);
|
||||
assert!(
|
||||
map.get(&sanity_user).is_some(),
|
||||
"sanity user should be cacheable after eviction reclaimed stale entries"
|
||||
);
|
||||
|
||||
drop(sanity_lock);
|
||||
evictor.abort();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-held-survive-{}", std::process::id());
|
||||
let held = quota_user_lock(&held_user);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3));
|
||||
|
||||
for idx in 0..512u32 {
|
||||
let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx);
|
||||
let temp = quota_user_lock(&user);
|
||||
drop(temp);
|
||||
if idx % 32 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
let reacquired = quota_user_lock(&held_user);
|
||||
assert!(
|
||||
Arc::ptr_eq(&held, &reacquired),
|
||||
"held user lock identity must remain stable across repeated evictions"
|
||||
);
|
||||
assert!(
|
||||
map.get(&held_user).is_some(),
|
||||
"held user entry must not be reclaimed while externally referenced"
|
||||
);
|
||||
|
||||
drop(reacquired);
|
||||
drop(held);
|
||||
|
||||
timeout(Duration::from_millis(300), async {
|
||||
loop {
|
||||
if map.get(&held_user).is_none() {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("released held lock must be reclaimed by periodic evictor");
|
||||
|
||||
evictor.abort();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
let prefix = format!("quota-evict-saturated-{}", std::process::id());
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
|
||||
|
||||
let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id());
|
||||
let overflow_before = quota_user_lock(&overflow_user);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_none(),
|
||||
"saturated map must initially route new user to overflow stripe"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4));
|
||||
|
||||
timeout(Duration::from_millis(400), async {
|
||||
loop {
|
||||
if map.len() < QUOTA_USER_LOCKS_MAX {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("periodic evictor must reclaim stale saturated entries");
|
||||
|
||||
let overflow_after = quota_user_lock(&overflow_user);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_some(),
|
||||
"after eviction, overflow user should become cacheable again"
|
||||
);
|
||||
assert!(
|
||||
Arc::strong_count(&overflow_after) >= 2,
|
||||
"cacheable lock should be held by map and caller"
|
||||
);
|
||||
|
||||
drop(overflow_before);
|
||||
drop(overflow_after);
|
||||
evictor.abort();
|
||||
}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::Waker;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
// Context stores a reference; leak one Waker for deterministic test scope.
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_map_churn_cannot_bypass_held_writer_lock() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-writer-user";
|
||||
let held_lock = quota_user_lock(user);
|
||||
let _held_guard = held_lock
|
||||
.try_lock()
|
||||
.expect("test must hold initial user lock before StatsIo poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
map.clear();
|
||||
let churned_lock = quota_user_lock(user);
|
||||
assert!(
|
||||
!Arc::ptr_eq(&held_lock, &churned_lock),
|
||||
"precondition: map churn should produce a distinct lock identity"
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"writer must remain pending on the originally-held lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_map_churn_cannot_bypass_held_reader_lock() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-reader-user";
|
||||
let held_lock = quota_user_lock(user);
|
||||
let _held_guard = held_lock
|
||||
.try_lock()
|
||||
.expect("test must hold initial user lock before StatsIo poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
map.clear();
|
||||
let churned_lock = quota_user_lock(user);
|
||||
assert!(
|
||||
!Arc::ptr_eq(&held_lock, &churned_lock),
|
||||
"precondition: map churn should produce a distinct lock identity"
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let mut storage = [0u8; 8];
|
||||
let mut read_buf = ReadBuf::new(&mut storage);
|
||||
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"reader must remain pending on the originally-held lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_no_lock_contention_keeps_writer_progress() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-progress-user";
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Ready(Ok(2))),
|
||||
"writer should progress immediately without contention"
|
||||
);
|
||||
}
|
||||
|
|
@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
|
||||
fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
|
@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
|
|||
|
||||
drop(retained);
|
||||
|
||||
quota_user_lock_evict();
|
||||
|
||||
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
|
||||
let overflow = quota_user_lock(&overflow_user);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,249 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
fn sleep_slot_ptr(slot: &Option<Pin<Box<tokio::time::Sleep>>>) -> usize {
|
||||
slot.as_ref()
|
||||
.map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-alloc-single-pending-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock to force retry scheduling");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_tests();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
|
||||
assert!(first.is_pending());
|
||||
let allocs_after_first = quota_retry_sleep_allocs_for_tests();
|
||||
let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep);
|
||||
|
||||
let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
|
||||
assert!(second.is_pending());
|
||||
let allocs_after_second = quota_retry_sleep_allocs_for_tests();
|
||||
let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep);
|
||||
|
||||
assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer");
|
||||
assert_eq!(
|
||||
allocs_after_second, 1,
|
||||
"repoll while the same timer is pending must not allocate again"
|
||||
);
|
||||
assert_eq!(
|
||||
ptr_after_first, ptr_after_second,
|
||||
"repoll while pending should retain the same timer allocation"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-alloc-per-cycle-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock to keep write path pending");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_tests();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
|
||||
let mut polls = 0u64;
|
||||
let mut observed_wakes = 0usize;
|
||||
let started = Instant::now();
|
||||
while started.elapsed() < Duration::from_millis(70) {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]);
|
||||
polls = polls.saturating_add(1);
|
||||
assert!(poll.is_pending());
|
||||
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed_wakes {
|
||||
observed_wakes = wakes;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let allocs = quota_retry_sleep_allocs_for_tests();
|
||||
assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers");
|
||||
assert!(
|
||||
allocs < polls,
|
||||
"timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-latency-envelope-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock for sustained contention");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_tests();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let started = Instant::now();
|
||||
let mut last_wakes = 0usize;
|
||||
let mut wake_instants = Vec::new();
|
||||
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > last_wakes {
|
||||
last_wakes = wakes;
|
||||
wake_instants.push(Instant::now());
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let mut max_gap = Duration::from_millis(0);
|
||||
for idx in 1..wake_instants.len() {
|
||||
let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]);
|
||||
if gap > max_gap {
|
||||
max_gap = gap;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
max_gap <= Duration::from_millis(35),
|
||||
"retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}"
|
||||
);
|
||||
assert!(
|
||||
quota_retry_sleep_allocs_for_tests() <= 16,
|
||||
"allocation cycles must remain bounded during a short contention window"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn micro_benchmark_release_to_completion_latency_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let rounds = 96usize;
|
||||
let mut samples_ms = Vec::with_capacity(rounds);
|
||||
|
||||
for round in 0..rounds {
|
||||
let user = format!("retry-release-latency-{}-{round}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock before spawning blocked writer");
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
io.write_all(&[0xD1]).await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(2)).await;
|
||||
let release_at = Instant::now();
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(120), writer)
|
||||
.await
|
||||
.expect("blocked writer must complete after release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(done.is_ok());
|
||||
|
||||
samples_ms.push(release_at.elapsed().as_millis() as u64);
|
||||
}
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1));
|
||||
let p95_ms = samples_ms[p95_idx];
|
||||
|
||||
assert!(
|
||||
p95_ms <= 40,
|
||||
"contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,241 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::ReadBuf;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}")));
|
||||
}
|
||||
retained
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_contention_wake_rate_decays_with_backoff_curve() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-bench-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before benchmark run");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(io).poll_write(&mut cx, &[0x71]);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let mut observed = vec![0usize; waiters];
|
||||
let start = Instant::now();
|
||||
let mut wakes_at_40ms = 0usize;
|
||||
let mut wakes_at_160ms = 0usize;
|
||||
|
||||
while start.elapsed() < Duration::from_millis(200) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
let wakes = counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed[idx] {
|
||||
observed[idx] = wakes;
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
|
||||
wakes_at_40ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
|
||||
wakes_at_160ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
let wakes_at_200ms = total_wakes;
|
||||
let early_window_wakes = wakes_at_40ms;
|
||||
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 28,
|
||||
"backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
assert!(
|
||||
early_window_wakes > 0,
|
||||
"benchmark failed to observe early contention wakes"
|
||||
);
|
||||
|
||||
assert!(
|
||||
late_window_wakes * 4 <= early_window_wakes * 3,
|
||||
"wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_read_contention_wake_rate_decays_with_backoff_curve() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-read-bench-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before read benchmark run");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(io).poll_read(&mut cx, &mut buf);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let mut observed = vec![0usize; waiters];
|
||||
let start = Instant::now();
|
||||
let mut wakes_at_40ms = 0usize;
|
||||
let mut wakes_at_160ms = 0usize;
|
||||
|
||||
while start.elapsed() < Duration::from_millis(200) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
let wakes = counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed[idx] {
|
||||
observed[idx] = wakes;
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
|
||||
wakes_at_40ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
|
||||
wakes_at_160ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
let wakes_at_200ms = total_wakes;
|
||||
let early_window_wakes = wakes_at_40ms;
|
||||
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 28,
|
||||
"read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
assert!(
|
||||
early_window_wakes > 0,
|
||||
"read benchmark failed to observe early contention wakes"
|
||||
);
|
||||
|
||||
assert!(
|
||||
late_window_wakes * 4 <= early_window_wakes * 3,
|
||||
"read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -0,0 +1,339 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::ReadBuf;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}")));
|
||||
}
|
||||
retained
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_writer_keeps_retry_wakes_zero() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
"quota-backoff-positive".to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]);
|
||||
assert!(poll.is_ready(), "uncontended writer must complete immediately");
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
0,
|
||||
"uncontended path must not schedule deferred contention wakes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-adversarial-writer";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while start.elapsed() < Duration::from_millis(80) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
|
||||
"sustained contention must be rate limited; observed wakes={} in 80ms",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]);
|
||||
assert!(ready.is_ready());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-adversarial-reader";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling reader");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while start.elapsed() < Duration::from_millis(80) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let mut next = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
|
||||
"sustained contention must be rate limited; observed wakes={} in 80ms",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
let mut done = ReadBuf::new(&mut storage);
|
||||
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done);
|
||||
assert!(ready.is_ready());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_backoff_attempt_resets_after_contention_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-edge-reset";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]);
|
||||
assert!(initial.is_pending());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(15)).await;
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > 0 {
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
|
||||
assert!(ready.is_ready());
|
||||
assert!(
|
||||
!io.quota_write_wake_scheduled,
|
||||
"successful write must clear deferred wake scheduling flag"
|
||||
);
|
||||
assert!(
|
||||
io.quota_write_retry_sleep.is_none(),
|
||||
"successful write must clear deferred sleep slot"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-fuzz-writer";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before fuzz loop");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let mut seed = 0x5EED_CAFE_7788_9900u64;
|
||||
for _ in 0..64 {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]);
|
||||
assert!(poll.is_pending());
|
||||
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let sleep_ms = (seed % 4) as u64;
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 24,
|
||||
"fuzzed repoll schedule must keep wake budget bounded; observed wakes={}",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-stress-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before launching stress waiters");
|
||||
|
||||
let waiters = 48usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(io).poll_write(&mut cx, &[0x61]);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
while start.elapsed() < Duration::from_millis(120) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
if counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 20,
|
||||
"stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -0,0 +1,246 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_quota_limited_writer_completes() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
"tdd-uncontended".to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let result = io.write_all(&[0x41, 0x42, 0x43]).await;
|
||||
assert!(result.is_ok(), "uncontended writer must complete");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-writer-storm-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writers");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let writers = 24usize;
|
||||
let mut ios = Vec::with_capacity(writers);
|
||||
let mut wake_counters = Vec::with_capacity(writers);
|
||||
|
||||
for _ in 0..writers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(poll.is_pending(), "writer must be pending under held lock");
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= writers * 4,
|
||||
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-reader-storm-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling readers");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let readers = 24usize;
|
||||
let mut ios = Vec::with_capacity(readers);
|
||||
let mut wake_counters = Vec::with_capacity(readers);
|
||||
|
||||
for _ in 0..readers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let poll = Pin::new(io).poll_read(&mut cx, &mut buf);
|
||||
assert!(poll.is_pending(), "reader must be pending under held lock");
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= readers * 4,
|
||||
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_contended_waiters_resume_after_lock_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-resume-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before launching waiters");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..12 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x5A]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let result = waiter.await.expect("waiter task must not panic");
|
||||
assert!(result.is_ok(), "waiter must complete after release");
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0x9E37_79B9_AA55_1234u64;
|
||||
for round in 0..20u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let writers = 8 + (seed as usize % 12);
|
||||
let sleep_ms = 10 + (seed as u64 % 15);
|
||||
let user = format!("tdd-fuzz-{}-{round}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock in fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut ios = Vec::with_capacity(writers);
|
||||
let mut wake_counters = Vec::with_capacity(writers);
|
||||
|
||||
for _ in 0..writers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]);
|
||||
assert!(matches!(poll, Poll::Pending));
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= writers * 4,
|
||||
"fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
|
|||
for _ in 0..8 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
wakes_after_first_yield,
|
||||
"writer contention should not schedule unbounded wake storms before lock acquisition"
|
||||
let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
assert!(
|
||||
wakes_after_second_window <= wakes_after_first_yield.saturating_add(2),
|
||||
"writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}"
|
||||
);
|
||||
|
||||
drop(held_lock);
|
||||
|
|
|
|||
|
|
@ -1884,6 +1884,32 @@ impl Stats {
|
|||
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
}
|
||||
self.maybe_cleanup_user_stats();
|
||||
let Some(stats) = self.user_stats.get(user) else {
|
||||
return;
|
||||
};
|
||||
|
||||
Self::touch_user_stats(stats.value());
|
||||
let counter = &stats.octets_to_client;
|
||||
let mut current = counter.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let next = current.saturating_sub(bytes);
|
||||
match counter.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return;
|
||||
|
|
@ -2440,3 +2466,7 @@ mod connection_lease_security_tests;
|
|||
#[cfg(test)]
|
||||
#[path = "tests/replay_checker_security_tests.rs"]
|
||||
mod replay_checker_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/user_octets_sub_security_tests.rs"]
|
||||
mod user_octets_sub_security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,151 @@
|
|||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_underflow_saturates_at_zero() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-underflow-user";
|
||||
|
||||
stats.add_user_octets_to(user, 3);
|
||||
stats.sub_user_octets_to(user, 100);
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_does_not_affect_octets_from_client() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-isolation-user";
|
||||
|
||||
stats.add_user_octets_from(user, 17);
|
||||
stats.add_user_octets_to(user, 5);
|
||||
stats.sub_user_octets_to(user, 3);
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), 19);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_add_sub_model_matches_saturating_reference() {
|
||||
let stats = Stats::new();
|
||||
let user = "sub-fuzz-user";
|
||||
let mut seed = 0x91D2_4CB8_EE77_1101u64;
|
||||
let mut model_to = 0u64;
|
||||
|
||||
for _ in 0..8192 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let amt = ((seed >> 8) & 0x3f) + 1;
|
||||
if (seed & 1) == 0 {
|
||||
stats.add_user_octets_to(user, amt);
|
||||
model_to = model_to.saturating_add(amt);
|
||||
} else {
|
||||
stats.sub_user_octets_to(user, amt);
|
||||
model_to = model_to.saturating_sub(amt);
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(user), model_to);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_parallel_add_sub_never_underflows_or_panics() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "sub-stress-user";
|
||||
// Pre-fund with a large offset so subtractions never saturate at zero.
|
||||
// This guarantees commutative updates, making the final state deterministic.
|
||||
let base_offset = 10_000_000u64;
|
||||
stats.add_user_octets_to(user, base_offset);
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for tid in 0..16u64 {
|
||||
let stats_for_thread = Arc::clone(&stats);
|
||||
workers.push(thread::spawn(move || {
|
||||
let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
|
||||
let mut net_delta = 0i64;
|
||||
for _ in 0..4096 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let amt = ((seed >> 8) & 0x1f) + 1;
|
||||
|
||||
if (seed & 1) == 0 {
|
||||
stats_for_thread.add_user_octets_to(user, amt);
|
||||
net_delta += amt as i64;
|
||||
} else {
|
||||
stats_for_thread.sub_user_octets_to(user, amt);
|
||||
net_delta -= amt as i64;
|
||||
}
|
||||
}
|
||||
|
||||
net_delta
|
||||
}));
|
||||
}
|
||||
|
||||
let mut expected_net_delta = 0i64;
|
||||
for worker in workers {
|
||||
expected_net_delta += worker
|
||||
.join()
|
||||
.expect("sub-user stress worker must not panic");
|
||||
}
|
||||
|
||||
let expected_total = (base_offset as i64 + expected_net_delta) as u64;
|
||||
let total = stats.get_user_total_octets(user);
|
||||
assert_eq!(
|
||||
total, expected_total,
|
||||
"concurrent add/sub lost updates or suffered ABA races"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_user_octets_to_missing_user_is_noop() {
|
||||
let stats = Stats::new();
|
||||
stats.sub_user_octets_to("missing-user", 1024);
|
||||
assert_eq!(stats.get_user_total_octets("missing-user"), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_parallel_per_user_models_remain_exact() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut workers = Vec::new();
|
||||
|
||||
for tid in 0..16u64 {
|
||||
let stats_for_thread = Arc::clone(&stats);
|
||||
workers.push(thread::spawn(move || {
|
||||
let user = format!("sub-per-user-{tid}");
|
||||
let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
|
||||
let mut model = 0u64;
|
||||
|
||||
for _ in 0..4096 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let amt = ((seed >> 8) & 0x3f) + 1;
|
||||
|
||||
if (seed & 1) == 0 {
|
||||
stats_for_thread.add_user_octets_to(&user, amt);
|
||||
model = model.saturating_add(amt);
|
||||
} else {
|
||||
stats_for_thread.sub_user_octets_to(&user, amt);
|
||||
model = model.saturating_sub(amt);
|
||||
}
|
||||
}
|
||||
|
||||
(user, model)
|
||||
}));
|
||||
}
|
||||
|
||||
for worker in workers {
|
||||
let (user, model) = worker
|
||||
.join()
|
||||
.expect("per-user subtract stress worker must not panic");
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
model,
|
||||
"per-user parallel model diverged"
|
||||
);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue