Compare commits

...

14 Commits

Author SHA1 Message Date
Alexey 0475844701
Merge branch 'flow' into flow 2026-03-23 11:35:44 +03:00
David Osipov 1abf9bd05c
Refactor CI workflows: rename build job and streamline stress testing setup 2026-03-23 12:27:57 +04:00
David Osipov 6f17d4d231
Add comprehensive security tests for quota management and relay functionality
- Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles.
- Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions.
- Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks.
- Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions.
- Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction.
- Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention.
2026-03-23 12:04:41 +04:00
Alexey bf30e93284
Merge pull request #545 from Dimasssss/patch-1
Update CONFIG_PARAMS.en.md and FAQ
2026-03-23 11:00:08 +03:00
David Osipov 91be148b72
Security hardening, concurrency fixes, and expanded test coverage
This commit introduces a comprehensive set of improvements to enhance
the security, reliability, and configurability of the proxy server,
specifically targeting adversarial resilience and high-load concurrency.

Security & Cryptography:
- Zeroize MTProto cryptographic key material (`dec_key`, `enc_key`)
  immediately after use to prevent memory leakage on early returns.
- Move TLS handshake replay tracking after full policy/ALPN validation
  to prevent cache poisoning by unauthenticated probes.
- Add `proxy_protocol_trusted_cidrs` configuration to restrict PROXY
  protocol headers to trusted networks, rejecting spoofed IPs.

Adversarial Resilience & DoS Mitigation:
- Implement "Tiny Frame Debt" tracking in the middle-relay to prevent
  CPU exhaustion from malicious 0-byte or 1-byte frame floods.
- Add `mask_relay_max_bytes` to strictly bound unauthenticated fallback
  connections, preventing the proxy from being abused as an open relay.
- Add a 5ms prefetch window (`mask_classifier_prefetch_timeout_ms`) to
  correctly assemble and classify fragmented HTTP/1.1 and HTTP/2 probes
  (e.g., `PRI * HTTP/2.0`) before routing them to masking heuristics.
- Prevent recursive masking loops (FD exhaustion) by verifying the mask
  target is not the proxy's own listener via local interface enumeration.

Concurrency & Reliability:
- Eliminate executor waker storms during quota lock contention by replacing
  the spin-waker task with inline `Sleep` and exponential backoff.
- Roll back user quota reservations (`rollback_me2c_quota_reservation`)
  if a network write fails, preventing Head-of-Line (HoL) blocking from
  permanently burning data quotas.
- Recover gracefully from idle-registry `Mutex` poisoning instead of
  panicking, ensuring isolated thread failures do not break the proxy.
- Fix `auth_probe_scan_start_offset` modulo logic to ensure bounds safety.

Testing:
- Add extensive adversarial, timing, fuzzing, and invariant test suites
  for both the client and handshake modules.
2026-03-22 23:09:49 +04:00
Dimasssss d4cda6d546
Update CONFIG_PARAMS.en.md 2026-03-22 21:56:21 +03:00
Alexey e35d69c61f
Merge pull request #544 from avbor/main
DOCS: VPS doube hop manual Ru\En
2026-03-22 21:45:13 +03:00
Dimasssss a353a94175
Update FAQ.en.md 2026-03-22 21:35:39 +03:00
Dimasssss b856250b2c
Update FAQ.ru.md 2026-03-22 21:30:17 +03:00
Alexander 97d1476ded
Merge branch 'flow' into main 2026-03-22 20:52:58 +03:00
Alexander cde14fc1bf
Create VPS_DOUBLE_HOP.en.md
Added VPS double hop with AmneziaWG manual
2026-03-22 20:35:09 +03:00
Alexander 5723d50d0b
Create VPS_DOUBLE_HOP.ru.md
Added VPS double hop with AmneziaWG manual
2026-03-22 20:04:14 +03:00
Alexey 3eb384e02a
Update middle_relay.rs 2026-03-22 17:53:32 +03:00
Dimasssss c960e0e245
Update CONFIG_PARAMS.en.md 2026-03-22 17:44:52 +03:00
97 changed files with 14742 additions and 255 deletions

View File

@ -11,7 +11,7 @@ env:
jobs:
build:
name: Build
name: Compile, Test, Lint
runs-on: ubuntu-latest
permissions:
@ -39,23 +39,11 @@ jobs:
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
- name: Compile (no tests)
run: cargo check --workspace --all-features --lib --bins --verbose
- name: Run tests
run: cargo test --verbose
- name: Stress quota-lock suites (PR only)
if: github.event_name == 'pull_request'
env:
RUST_TEST_THREADS: 16
run: |
set -euo pipefail
for i in $(seq 1 12); do
echo "[quota-lock-stress] iteration ${i}/12"
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
done
- name: Run tests (single pass)
run: cargo test --workspace --all-features --verbose
# clippy dont fail on warnings because of active development of telemt
# and many warnings

57
.github/workflows/stress.yml vendored Normal file
View File

@ -0,0 +1,57 @@
name: Stress Tests
on:
workflow_dispatch:
schedule:
- cron: '0 2 * * *'
pull_request:
branches: ["*"]
paths:
- src/proxy/**
- src/transport/**
- src/stream/**
- src/protocol/**
- src/tls_front/**
- Cargo.toml
- Cargo.lock
env:
CARGO_TERM_COLOR: always
jobs:
quota-lock-stress:
name: Quota-lock stress loop
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Cache cargo registry and build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-stress-
${{ runner.os }}-cargo-
- name: Run quota-lock stress suites
env:
RUST_TEST_THREADS: 16
run: |
set -euo pipefail
for i in $(seq 1 12); do
echo "[quota-lock-stress] iteration ${i}/12"
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
done

38
Cargo.lock generated
View File

@ -1454,9 +1454,9 @@ dependencies = [
[[package]]
name = "iri-string"
version = "0.7.10"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
dependencies = [
"memchr",
"serde",
@ -1486,7 +1486,7 @@ dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
@ -1495,9 +1495,31 @@ dependencies = [
[[package]]
name = "jni-sys"
version = "0.3.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
dependencies = [
"jni-sys 0.4.1",
]
[[package]]
name = "jni-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
dependencies = [
"jni-sys-macros",
]
[[package]]
name = "jni-sys-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "jobserver"
@ -1659,9 +1681,9 @@ dependencies = [
[[package]]
name = "moka"
version = "0.12.14"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
dependencies = [
"aes",
"anyhow",

View File

@ -1,8 +1,11 @@
[package]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
edition = "2024"
[features]
redteam_offline_expected_fail = []
[dependencies]
# C
libc = "0.2"

View File

@ -20,7 +20,7 @@ This document lists all configuration keys accepted by `config.toml`.
| Parameter | Type | Default | Constraints / validation | Description |
|---|---|---|---|---|
| data_path | `String \| null` | `null` | — | Optional runtime data directory path. |
| prefer_ipv6 | `bool` | `false` | — | Prefer IPv6 where applicable in runtime logic. |
| prefer_ipv6 | `bool` | `false` | Deprecated. Use `network.prefer`. | Deprecated legacy IPv6 preference flag migrated to `network.prefer`. |
| fast_mode | `bool` | `true` | — | Enables fast-path optimizations for traffic processing. |
| use_middle_proxy | `bool` | `true` | none | Enables ME transport mode; if `false`, runtime falls back to direct DC routing. |
| proxy_secret_path | `String \| null` | `"proxy-secret"` | Path may be `null`. | Path to Telegram infrastructure proxy-secret file used by ME handshake logic. |
@ -44,6 +44,7 @@ This document lists all configuration keys accepted by `config.toml`.
| me_writer_cmd_channel_capacity | `usize` | `4096` | Must be `> 0`. | Capacity of per-writer command channel. |
| me_route_channel_capacity | `usize` | `768` | Must be `> 0`. | Capacity of per-connection ME response route channel. |
| me_c2me_channel_capacity | `usize` | `1024` | Must be `> 0`. | Capacity of per-client command queue (client reader -> ME sender). |
| me_c2me_send_timeout_ms | `u64` | `4000` | `0..=60000`. | Maximum wait for enqueueing client->ME commands when the per-client queue is full (`0` keeps legacy unbounded wait). |
| me_reader_route_data_wait_ms | `u64` | `2` | `0..=20`. | Bounded wait for routing ME DATA to per-connection queue (`0` = no wait). |
| me_d2c_flush_batch_max_frames | `usize` | `32` | `1..=512`. | Max ME->client frames coalesced before flush. |
| me_d2c_flush_batch_max_bytes | `usize` | `131072` | `4096..=2_097_152`. | Max ME->client payload bytes coalesced before flush. |
@ -105,6 +106,8 @@ This document lists all configuration keys accepted by `config.toml`.
| me_warn_rate_limit_ms | `u64` | `5000` | Must be `> 0`. | Cooldown for repetitive ME warning logs (ms). |
| me_route_no_writer_mode | `"async_recovery_failfast" \| "inline_recovery_legacy" \| "hybrid_async_persistent"` | `"hybrid_async_persistent"` | — | Route behavior when no writer is immediately available. |
| me_route_no_writer_wait_ms | `u64` | `250` | `10..=5000`. | Max wait in async-recovery failfast mode (ms). |
| me_route_hybrid_max_wait_ms | `u64` | `3000` | `50..=60000`. | Maximum cumulative wait in hybrid no-writer mode before failfast fallback (ms). |
| me_route_blocking_send_timeout_ms | `u64` | `250` | `0..=5000`. | Maximum wait for blocking route-channel send fallback (`0` keeps legacy unbounded wait). |
| me_route_inline_recovery_attempts | `u32` | `3` | Must be `> 0`. | Inline recovery attempts in legacy mode. |
| me_route_inline_recovery_wait_ms | `u64` | `3000` | `10..=30000`. | Max inline recovery wait in legacy mode (ms). |
| fast_mode_min_tls_record | `usize` | `0` | — | Minimum TLS record size when fast-mode coalescing is enabled (`0` disables). |
@ -124,6 +127,7 @@ This document lists all configuration keys accepted by `config.toml`.
| me_secret_atomic_snapshot | `bool` | `true` | — | Keeps selector and secret bytes from the same snapshot atomically. |
| proxy_secret_len_max | `usize` | `256` | Must be within `[32, 4096]`. | Upper length limit for accepted proxy-secret bytes. |
| me_pool_drain_ttl_secs | `u64` | `90` | none | Time window where stale writers remain fallback-eligible after map change. |
| me_instadrain | `bool` | `false` | — | Forces draining stale writers to be removed on the next cleanup tick, bypassing TTL/deadline waiting. |
| me_pool_drain_threshold | `u64` | `128` | — | Max draining stale writers before batch force-close (`0` disables threshold cleanup). |
| me_pool_drain_soft_evict_enabled | `bool` | `true` | — | Enables gradual soft-eviction of stale writers during drain/reinit instead of immediate hard close. |
| me_pool_drain_soft_evict_grace_secs | `u64` | `30` | `0..=3600`. | Grace period before stale writers become soft-evict candidates. |
@ -203,6 +207,7 @@ This document lists all configuration keys accepted by `config.toml`.
| metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. |
| metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. |
| max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). |
| accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). |
Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted.
@ -229,7 +234,7 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
|---|---|---|---|---|
| ip | `IpAddr` | — | — | Listener bind IP. |
| announce | `String \| null` | — | — | Public IP/domain announced in proxy links (priority over `announce_ip`). |
| announce_ip | `IpAddr \| null` | — | | Deprecated legacy announce IP (migrated to `announce` if needed). |
| announce_ip | `IpAddr \| null` | — | Deprecated. Use `announce`. | Deprecated legacy announce IP (migrated to `announce` if needed). |
| proxy_protocol | `bool \| null` | `null` | — | Per-listener override for PROXY protocol enable flag. |
| reuse_allow | `bool` | `false` | — | Enables `SO_REUSEPORT` for multi-instance bind sharing. |
@ -269,6 +274,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. |
| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. |
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |

View File

@ -3,7 +3,7 @@
1. Go to @MTProxybot bot.
2. Enter the command `/newproxy`
3. Send the server IP and port. For example: 1.2.3.4:443
4. Open the config `nano /etc/telemt.toml`.
4. Open the config `nano /etc/telemt/telemt.toml`.
5. Copy and send the user secret from the [access.users] section to the bot.
6. Copy the tag received from the bot. For example 1234567890abcdef1234567890abcdef.
> [!WARNING]
@ -33,6 +33,9 @@ hello = "ad_tag"
hello2 = "ad_tag2"
```
## Why is middle proxy (ME) needed
https://github.com/telemt/telemt/discussions/167
## How many people can use 1 link
By default, 1 link can be used by any number of people.

View File

@ -3,7 +3,7 @@
1. Зайти в бота @MTProxybot.
2. Ввести команду `/newproxy`
3. Отправить IP и порт сервера. Например: 1.2.3.4:443
4. Открыть конфиг `nano /etc/telemt.toml`.
4. Открыть конфиг `nano /etc/telemt/telemt.toml`.
5. Скопировать и отправить боту секрет пользователя из раздела [access.users].
6. Скопировать полученный tag у бота. Например 1234567890abcdef1234567890abcdef.
> [!WARNING]
@ -33,6 +33,10 @@ hello = "ad_tag"
hello2 = "ad_tag2"
```
## Зачем нужен middle proxy (ME)
https://github.com/telemt/telemt/discussions/167
## Сколько человек может пользоваться 1 ссылкой
По умолчанию 1 ссылкой может пользоваться сколько угодно человек.

283
docs/VPS_DOUBLE_HOP.en.md Normal file
View File

@ -0,0 +1,283 @@
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
## Concept
- **Server A** (__conditionally Russian Federation_):\
Entry point, receives Telegram proxy user traffic via **HAProxy** (port `443`)\
and sends it to the tunnel to Server **B**.\
Internal IP in the tunnel — `10.10.10.2`\
Port for HAProxy clients — `443\tcp`
- **Server B** (_conditionally Netherlands_):\
Exit point, runs **telemt** and accepts client connections through Server **A**.\
The server must have unrestricted access to Telegram servers.\
Internal IP in the tunnel — `10.10.10.1`\
AmneziaWG port — `8443\udp`\
Port for telemt clients — `443\tcp`
---
## Step 1. Setting up the AmneziaWG tunnel (A <-> B)
[AmneziaWG](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module) must be installed on all servers.\
All following commands are given for **Ubuntu 24.04**.\
For RHEL-based distributions, installation instructions are available at the link above.
### Installing AmneziaWG (Servers A and B)
The following steps must be performed on each server:
#### 1. Adding the AmneziaWG repository and installing required packages:
```bash
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
sudo add-apt-repository ppa:amnezia/ppa && \
sudo apt-get install -y amneziawg
```
#### 2. Generating a unique key pair:
```bash
cd /etc/amnezia/amneziawg && \
awg genkey | tee private.key | awg pubkey > public.key
```
As a result, you will get two files in the `/etc/amnezia/amneziawg` folder:\
`private.key` - private, and\
`public.key` - public server keys
#### 3. Configuring network interfaces:
Obfuscation parameters `S1`, `S2`, `H1`, `H2`, `H3`, `H4` must be strictly identical on both servers.\
Parameters `Jc`, `Jmin` and `Jmax` can differ.\
Parameters `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) must be specified on the client side (Server **A**).
Recommendations for choosing values:
```text
Jc — 1 ≤ Jc ≤ 128; from 4 to 12 inclusive
Jmin — Jmax > Jmin < 1280*; recommended 8
Jmax — Jmin < Jmax 1280*; recommended 80
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
recommended range from 15 to 150 inclusive
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
recommended range from 15 to 150 inclusive
H1/H2/H3/H4 — must be unique and differ from each other;
recommended range from 5 to 2147483647 inclusive
* It is assumed that the Internet connection has an MTU of 1280.
```
> [!IMPORTANT]
> It is recommended to use your own, unique values.\
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters.
#### Server B Configuration (Netherlands):
Create the interface configuration file (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
File content
```ini
[Interface]
Address = 10.10.10.1/24
ListenPort = 8443
PrivateKey = <PRIVATE_KEY_SERVER_B>
SaveConfig = true
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_A>
AllowedIPs = 10.10.10.2/32
```
`ListenPort` - the port on which the server will wait for connections, you can choose any free one.\
`<PRIVATE_KEY_SERVER_B>` - the content of the `private.key` file from Server **B**.\
`<PUBLIC_KEY_SERVER_A>` - the content of the `public.key` file from Server **A**.
Open the port on the firewall (if enabled):
```bash
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
```
`<PUBLIC_IP_SERVER_A>` - the external IP address of Server **A**.
#### Server A Configuration (Russian Federation):
Create the interface configuration file (awg0)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
File content
```ini
[Interface]
Address = 10.10.10.2/24
PrivateKey = <PRIVATE_KEY_SERVER_A>
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
I1 = <b 0xc10000000108981eba846e21f74e00>
I2 = <b 0xc20000000108981eba846e21f74e00>
I3 = <b 0xc30000000108981eba846e21f74e00>
I4 = <b 0x43981eba846e21f74e>
I5 = <b 0x43981eba846e21f74e>
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_B>
Endpoint = <PUBLIC_IP_SERVER_B>:8443
AllowedIPs = 10.10.10.1/32
PersistentKeepalive = 25
```
`<PRIVATE_KEY_SERVER_A>` - the content of the `private.key` file from Server **A**.\
`<PUBLIC_KEY_SERVER_B>` - the content of the `public.key` file from Server **B**.\
`<PUBLIC_IP_SERVER_B>` - the public IP address of Server **B**.
Enable the tunnel on both servers:
```bash
sudo systemctl enable --now awg-quick@awg0
```
Make sure Server B is accessible from Server A through the tunnel.
```bash
ping 10.10.10.1
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
^C
```
---
## Step 2. Installing telemt on Server B (conditionally Netherlands)
Installation and configuration are described [here](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) or [here](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
It is assumed that telemt expects connections on port `443\tcp`.
In the telemt config, you must enable the `Proxy` protocol and restrict connections to it only through the tunnel.
```toml
[server]
port = 443
listen_addr_ipv4 = "10.10.10.1"
proxy_protocol = true
```
Also, for correct link generation, specify the FQDN or IP address and port of Server `A`
```toml
[general.links]
show = "*"
public_host = "<FQDN_OR_IP_SERVER_A>"
public_port = 443
```
Open the port on the firewall (if enabled):
```bash
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
```
---
## Step 3. Configuring HAProxy on Server A (Russian Federation)
Since the version in the standard Ubuntu repository is relatively old, it makes sense to use the official Docker image.\
[Instructions](https://docs.docker.com/engine/install/ubuntu/) for installing Docker on Ubuntu.
> [!WARNING]
> By default, regular users do not have rights to use ports < 1024.
> Attempts to run HAProxy on port 443 can lead to errors:
> ```
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
> ```
> There are two simple ways to bypass this restriction, choose one:
> 1. At the OS level, change the net.ipv4.ip_unprivileged_port_start setting to allow users to use all ports:
> ```
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
> ```
> or
>
> 2. Run HAProxy as root:
> Uncomment the `user: "root"` parameter in docker-compose.yaml.
#### Create a folder for HAProxy:
```bash
mkdir -p /opt/docker-compose/haproxy && cd $_
```
#### Create the docker-compose.yaml file
`nano docker-compose.yaml`
File content
```yaml
services:
haproxy:
image: haproxy:latest
container_name: haproxy
restart: unless-stopped
# user: "root"
network_mode: "host"
volumes:
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
logging:
driver: "json-file"
options:
max-size: "1m"
max-file: "1"
```
#### Create the haproxy.cfg config file
Accept connections on port 443\tcp and send them through the tunnel to Server `B` 10.10.10.1:443
`nano haproxy.cfg`
File content
```haproxy
global
log stdout format raw local0
maxconn 10000
defaults
log global
mode tcp
option tcplog
option clitcpka
option srvtcpka
timeout connect 5s
timeout client 2h
timeout server 2h
timeout check 5s
frontend tcp_in_443
bind *:443
maxconn 8000
option tcp-smart-accept
default_backend telemt_nodes
backend telemt_nodes
option tcp-smart-connect
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
```
> [!WARNING]
> **The file must end with an empty line, otherwise HAProxy will not start!**
#### Allow port 443\tcp in the firewall (if enabled)
```bash
sudo ufw allow 443/tcp
```
#### Start the HAProxy container
```bash
docker compose up -d
```
If everything is configured correctly, you can now try connecting Telegram clients using links from the telemt log\api.

287
docs/VPS_DOUBLE_HOP.ru.md Normal file
View File

@ -0,0 +1,287 @@
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
## Концепция
- **Сервер A** (_РФ_):\
Точка входа, принимает трафик пользователей Telegram-прокси через **HAProxy** (порт `443`)\
и отправляет в туннель на Сервер **B**.\
Внутренний IP в туннеле — `10.10.10.2`\
Порт для клиентов HAProxy — `443\tcp`
- **Сервер B** (_условно Нидерланды_):\
Точка выхода, на нем работает **telemt** и принимает подключения клиентов через Сервер **A**.\
На сервере должен быть неограниченный доступ до серверов Telegram.\
Внутренний IP в туннеле — `10.10.10.1`\
Порт AmneziaWG — `8443\udp`\
Порт для клиентов telemt — `443\tcp`
---
## Шаг 1. Настройка туннеля AmneziaWG (A <-> B)
На всех серверах необходимо установить [amneziawg](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module).\
Далее все команды даны для **Ununtu 24.04**.\
Для RHEL-based дистрибутивов инструкция по установке есть по ссылке выше.
### Установка AmneziaWG (Сервера A и B)
На каждом из серверов необходимо выполнить следующие шаги:
#### 1. Добавление репозитория AmneziaWG и установка необходимых пакетов:
```bash
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
sudo add-apt-repository ppa:amnezia/ppa && \
sudo apt-get install -y amneziawg
```
#### 2. Генерация уникальной пары ключей:
```bash
cd /etc/amnezia/amneziawg && \
awg genkey | tee private.key | awg pubkey > public.key
```
В результате вы получите в папке `/etc/amnezia/amneziawg` два файла:\
`private.key` - приватный и\
`public.key` - публичный ключи сервера
#### 3. Настройка сетевых интерфейсов:
Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\
Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**).
Рекомендации по выбору значений:
```text
Jc — 1 ≤ Jc ≤ 128; от 4 до 12 включительно
Jmin — Jmax > Jmin < 1280*; рекомендовано 8
Jmax — Jmin < Jmax 1280*; рекомендовано 80
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
рекомендованный диапазон от 15 до 150 включительно
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
рекомендованный диапазон от 15 до 150 включительно
H1/H2/H3/H4 — должны быть уникальны и отличаться друг от друга;
рекомендованный диапазон от 5 до 2147483647 включительно
* Предполагается, что подключение к Интернету имеет MTU 1280.
```
> [!IMPORTANT]
> Рекомендуется использовать собственные, уникальные значения.\
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html).
#### Конфигурация Сервера B (_Нидерланды_):
Создаем файл конфигурации интерфейса (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
Содержимое файла
```ini
[Interface]
Address = 10.10.10.1/24
ListenPort = 8443
PrivateKey = <PRIVATE_KEY_SERVER_B>
SaveConfig = true
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_A>
AllowedIPs = 10.10.10.2/32
```
`ListenPort` - порт, на котором сервер будет ждать подключения, можете выбрать любой свободный.\
`<PRIVATE_KEY_SERVER_B>` - содержимое файла `private.key` с сервера **B**.\
`<PUBLIC_KEY_SERVER_A>` - содержимое файла `public.key` с сервера **A**.
Открываем порт на фаерволе (если включен):
```bash
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
```
`<PUBLIC_IP_SERVER_A>` - внешний IP адрес Сервера **A**.
#### Конфигурация Сервера A (_РФ_):
Создаем файл конфигурации интерфейса (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
Содержимое файла
```ini
[Interface]
Address = 10.10.10.2/24
PrivateKey = <PRIVATE_KEY_SERVER_A>
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
I1 = <b 0xc10000000108981eba846e21f74e00>
I2 = <b 0xc20000000108981eba846e21f74e00>
I3 = <b 0xc30000000108981eba846e21f74e00>
I4 = <b 0x43981eba846e21f74e>
I5 = <b 0x43981eba846e21f74e>
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_B>
Endpoint = <PUBLIC_IP_SERVER_B>:8443
AllowedIPs = 10.10.10.1/32
PersistentKeepalive = 25
```
`<PRIVATE_KEY_SERVER_A>` - содержимое файла `private.key` с сервера **A**.\
`<PUBLIC_KEY_SERVER_B>` - содержимое файла `public.key` с сервера **B**.\
`<PUBLIC_IP_SERVER_B>` - публичный IP адресс сервера **B**.
#### Включаем туннель на обоих серверах:
```bash
sudo systemctl enable --now awg-quick@awg0
```
Убедитесь, что с Сервера `A` доступен Сервер `B` через туннель.
```bash
ping 10.10.10.1
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
^C
```
---
## Шаг 2. Установка telemt на Сервере B (_условно Нидерланды_)
Установка и настройка описаны [здесь](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) или [здесь](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
Подразумевается что telemt ожидает подключения на порту `443\tcp`.
В конфиге telemt необходимо включить протокол `Proxy` и ограничить подключения к нему только через туннель.
```toml
[server]
port = 443
listen_addr_ipv4 = "10.10.10.1"
proxy_protocol = true
```
А также, для правильной генерации ссылок, указать FQDN или IP адрес и порт Сервера `A`
```toml
[general.links]
show = "*"
public_host = "<FQDN_OR_IP_SERVER_A>"
public_port = 443
```
Открываем порт на фаерволе (если включен):
```bash
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
```
---
### Шаг 3. Настройка HAProxy на Сервере A (_РФ_)
Т.к. в стандартном репозитории Ubuntu версия относительно старая, имеет смысл воспользоваться официальным образом Docker.\
[Инструкция](https://docs.docker.com/engine/install/ubuntu/) по установке Docker на Ubuntu.
> [!WARNING]
> По умолчанию у обычных пользователей нет прав на использование портов < 1024.\
> Попытки запустить HAProxy на 443 порту могут приводить к ошибкам:
> ```
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
> ```
> Есть два простых способа обойти это ограничение, выберите что-то одно:
> 1. На уровне ОС изменить настройку net.ipv4.ip_unprivileged_port_start, разрешив пользователям использовать все порты:
> ```
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
> ```
> или
>
> 2. Запустить HAProxy под root:\
> Раскомментируйте в docker-compose.yaml параметр `user: "root"`.
#### Создаем папку для HAProxy:
```bash
mkdir -p /opt/docker-compose/haproxy && cd $_
```
#### Создаем файл docker-compose.yaml
`nano docker-compose.yaml`
Содержимое файла
```yaml
services:
haproxy:
image: haproxy:latest
container_name: haproxy
restart: unless-stopped
# user: "root"
network_mode: "host"
volumes:
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
logging:
driver: "json-file"
options:
max-size: "1m"
max-file: "1"
```
#### Создаем файл конфига haproxy.cfg
Принимаем подключения на порту 443\tcp и отправляем их через туннель на Сервер `B` 10.10.10.1:443
`nano haproxy.cfg`
Содержимое файла
```haproxy
global
log stdout format raw local0
maxconn 10000
defaults
log global
mode tcp
option tcplog
option clitcpka
option srvtcpka
timeout connect 5s
timeout client 2h
timeout server 2h
timeout check 5s
frontend tcp_in_443
bind *:443
maxconn 8000
option tcp-smart-accept
default_backend telemt_nodes
backend telemt_nodes
option tcp-smart-connect
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
```
>[!WARNING]
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!**
#### Разрешаем порт 443\tcp в фаерволе (если включен)
```bash
sudo ufw allow 443/tcp
```
#### Запускаем контейнер HAProxy
```bash
docker compose up -d
```
Если все настроено верно, то теперь можно пробовать подключить клиентов Telegram с использованием ссылок из лога\api telemt.

View File

@ -553,6 +553,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
512
}
#[cfg(not(test))]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
5 * 1024 * 1024
}
#[cfg(test)]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
32 * 1024
}
pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 {
5
}
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
false
}

View File

@ -600,6 +600,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes
|| old.censorship.mask_classifier_prefetch_timeout_ms
!= new.censorship.mask_classifier_prefetch_timeout_ms
|| old.censorship.mask_timing_normalization_enabled
!= new.censorship.mask_timing_normalization_enabled
|| old.censorship.mask_timing_normalization_floor_ms

View File

@ -430,6 +430,25 @@ impl ProxyConfig {
));
}
if config.censorship.mask_relay_max_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_relay_max_bytes > 67_108_864 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),
));
}
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
return Err(ProxyError::Config(
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"
.to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms
< config.censorship.mask_timing_normalization_floor_ms
{
@ -1134,6 +1153,10 @@ mod load_security_tests;
#[path = "tests/load_mask_shape_security_tests.rs"]
mod load_mask_shape_security_tests;
#[cfg(test)]
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
mod load_mask_classifier_prefetch_timeout_security_tests;
#[cfg(test)]
mod tests {
use super::*;

View File

@ -0,0 +1,75 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir()
.join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml"));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 4
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout below minimum security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 51
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout above max security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 20
"#,
);
let cfg = ProxyConfig::load(&path)
.expect("prefetch timeout within security bounds must be accepted");
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
remove_temp_config(&path);
}

View File

@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
"error must explain non-zero relay cap invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 67108865
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("mask_relay_max_bytes above hard cap must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
"error must explain relay cap upper bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_valid_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 8388608
"#,
);
let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
remove_temp_config(&path);
}

View File

@ -1450,6 +1450,14 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
#[serde(default = "default_mask_relay_max_bytes")]
pub mask_relay_max_bytes: usize,
/// Prefetch timeout (ms) for extending fragmented masking classifier window.
#[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
pub mask_classifier_prefetch_timeout_ms: u64,
/// Enable outcome-time normalization envelope for masking fallback.
#[serde(default = "default_mask_timing_normalization_enabled")]
pub mask_timing_normalization_enabled: bool,
@ -1488,6 +1496,8 @@ impl Default for AntiCensorshipConfig {
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
mask_relay_max_bytes: default_mask_relay_max_bytes(),
mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),

View File

@ -32,6 +32,14 @@ pub(crate) struct RuntimeWatches {
pub(crate) detected_ip_v6: Option<IpAddr>,
}
const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
))
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn spawn_runtime_tasks(
config: &Arc<ProxyConfig>,
@ -69,6 +77,8 @@ pub(crate) async fn spawn_runtime_tasks(
rc_clone.run_periodic_cleanup().await;
});
spawn_quota_lock_maintenance_task();
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
debug!(
@ -360,3 +370,24 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
.await;
startup_tracker.mark_ready().await;
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
let handle = spawn_quota_lock_maintenance_task();
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
assert_eq!(
crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
1,
"runtime maintenance path must spawn exactly one quota lock evictor task per call"
);
handle.abort();
}
}

View File

@ -186,6 +186,67 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
}
}
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
#[cfg(test)]
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
}
fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
return false;
}
if initial_data.is_empty() {
// Empty initial_data means there is no client probe prefix to refine.
// Prefetching in this case can consume fallback relay payload bytes and
// accidentally route them through shaping heuristics.
return false;
}
if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
return false;
}
initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ')
}
#[cfg(test)]
async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
where
R: AsyncRead + Unpin,
{
extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT)
.await;
}
async fn extend_masking_initial_window_with_timeout<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
)
where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
fn masking_outcome<R, W>(
reader: R,
writer: W,
@ -200,6 +261,15 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
HandshakeOutcome::NeedsMasking(Box::pin(async move {
let mut reader = reader;
let mut initial_data = initial_data;
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
mask_classifier_prefetch_timeout(&config),
)
.await;
handle_bad_client(
reader,
writer,
@ -1321,6 +1391,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
mod masking_probe_evasion_blackhat_tests;
#[cfg(test)]
#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
mod masking_fragmented_classifier_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_replay_timing_security_tests.rs"]
mod masking_replay_timing_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
mod masking_http2_fragmented_preface_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
mod masking_prefetch_invariant_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"]
mod masking_prefetch_timing_matrix_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"]
mod masking_prefetch_config_runtime_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"]
mod masking_prefetch_config_pipeline_integration_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"]
mod masking_prefetch_strict_boundary_security_tests;
#[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests;

View File

@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
hasher.finish() as usize
}
fn auth_probe_scan_start_offset(
peer_ip: IpAddr,
now: Instant,
state_len: usize,
scan_limit: usize,
) -> usize {
if state_len == 0 || scan_limit == 0 {
return 0;
}
auth_probe_eviction_offset(peer_ip, now) % state_len
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -269,11 +282,7 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
@ -769,7 +778,7 @@ where
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_key = Zeroizing::new(sha256(&dec_key_input));
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
@ -805,7 +814,7 @@ where
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_key = Zeroizing::new(sha256(&enc_key_input));
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
@ -830,9 +839,9 @@ where
user: user.clone(),
dc_idx,
proto_tag,
dec_key,
dec_key: *dec_key,
dec_iv,
enc_key,
enc_key: *enc_key,
enc_iv,
peer,
is_tls,
@ -979,6 +988,18 @@ mod saturation_poison_security_tests;
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"]
mod auth_probe_scan_budget_security_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
mod auth_probe_scan_offset_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"]
mod auth_probe_eviction_bias_security_tests;
#[cfg(test)]
#[path = "tests/handshake_advanced_clever_tests.rs"]
mod advanced_clever_tests;
@ -995,6 +1016,10 @@ mod real_bug_stress_tests;
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
mod timing_manual_bench_tests;
#[cfg(test)]
#[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
mod handshake_key_material_zeroization_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -4,14 +4,23 @@ use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
#[cfg(unix)]
use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::time::Duration;
#[cfg(unix)]
use std::sync::{Mutex, OnceLock};
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(unix)]
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::{Instant, timeout};
use tracing::debug;
@ -30,13 +39,23 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
#[cfg(unix)]
#[cfg(not(test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
#[cfg(all(unix, test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);
struct CopyOutcome {
total: usize,
ended_by_eof: bool,
}
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
byte_cap: usize,
shutdown_on_eof: bool,
) -> CopyOutcome
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
@ -44,14 +63,31 @@ where
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
let mut ended_by_eof = false;
if byte_cap == 0 {
return CopyOutcome {
total,
ended_by_eof,
};
}
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
ended_by_eof = true;
if shutdown_on_eof {
let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
}
break;
}
total = total.saturating_add(n);
@ -68,6 +104,39 @@ where
}
}
fn is_http_probe(data: &[u8]) -> bool {
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
const HTTP_METHODS: [&[u8]; 10] = [
b"GET ",
b"POST",
b"HEAD",
b"PUT ",
b"DELETE",
b"OPTIONS",
b"CONNECT",
b"TRACE",
b"PATCH",
b"PRI ",
];
if data.is_empty() {
return false;
}
let window = &data[..data.len().min(16)];
for method in HTTP_METHODS {
if data.len() >= method.len() && window.starts_with(method) {
return true;
}
if (2..=3).contains(&window.len()) && method.starts_with(window) {
return true;
}
}
false
}
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
@ -125,6 +194,11 @@ async fn maybe_write_shape_padding<W>(
let mut remaining = target_total - total_sent;
let mut pad_chunk = [0u8; 1024];
let deadline = Instant::now() + MASK_TIMEOUT;
// Use a Send RNG so relay futures remain spawn-safe under Tokio.
let mut rng = {
let mut seed_source = rand::rng();
StdRng::from_rng(&mut seed_source)
};
while remaining > 0 {
let now = Instant::now();
@ -133,10 +207,7 @@ async fn maybe_write_shape_padding<W>(
}
let write_len = remaining.min(pad_chunk.len());
{
let mut rng = rand::rng();
rng.fill_bytes(&mut pad_chunk[..write_len]);
}
rng.fill_bytes(&mut pad_chunk[..write_len]);
let write_budget = deadline.saturating_duration_since(now);
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
Ok(Ok(())) => {}
@ -167,11 +238,11 @@ where
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
.await
.is_err()
{
@ -190,6 +261,9 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
if config.censorship.mask_timing_normalization_enabled {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if floor == 0 {
return MASK_TIMEOUT;
}
if ceiling > floor {
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(floor..=ceiling));
@ -219,14 +293,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4
&& (data.starts_with(b"GET ")
|| data.starts_with(b"POST")
|| data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{
if is_http_probe(data) {
return "HTTP";
}
@ -248,6 +315,244 @@ fn detect_client_type(data: &[u8]) -> &'static str {
"unknown"
}
fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
if host.starts_with('[') && host.ends_with(']') {
return host[1..host.len() - 1].parse::<IpAddr>().ok();
}
host.parse::<IpAddr>().ok()
}
fn canonical_ip(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)),
IpAddr::V4(v4) => IpAddr::V4(v4),
}
}
#[cfg(unix)]
fn collect_local_interface_ips() -> Vec<IpAddr> {
#[cfg(test)]
LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
let mut out = Vec::new();
if let Ok(addrs) = getifaddrs() {
for iface in addrs {
if let Some(address) = iface.address {
if let Some(v4) = address.as_sockaddr_in() {
out.push(canonical_ip(IpAddr::V4(v4.ip())));
} else if let Some(v6) = address.as_sockaddr_in6() {
out.push(canonical_ip(IpAddr::V6(v6.ip())));
}
}
}
}
out
}
fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
if refreshed.is_empty() && !previous.is_empty() {
return previous.to_vec();
}
refreshed
}
#[cfg(unix)]
#[derive(Default)]
struct LocalInterfaceCache {
ips: Vec<IpAddr>,
refreshed_at: Option<StdInstant>,
}
#[cfg(unix)]
static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
#[cfg(unix)]
static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
#[cfg(all(unix, test))]
fn local_interface_ips() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
let refreshed = collect_local_interface_ips();
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(unix)]
async fn local_interface_ips_async() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let _refresh_guard = refresh_lock.lock().await;
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
.await
.unwrap_or_default();
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(all(not(unix), test))]
fn local_interface_ips() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(not(unix))]
async fn local_interface_ips_async() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(test)]
static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
fn reset_local_interface_enumerations_for_tests() {
LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
#[cfg(unix)]
if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
guard.ips.clear();
guard.refreshed_at = None;
}
}
#[cfg(test)]
fn local_interface_enumerations_for_tests() -> usize {
LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
}
fn is_mask_target_local_listener_with_interfaces(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
interface_ips: &[IpAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override {
let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (resolved_ip.is_loopback()
|| resolved_ip.is_unspecified()
|| interface_ips.contains(&resolved_ip))
{
return true;
}
}
if let Some(mask_ip) = literal_mask_ip {
if mask_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (mask_ip.is_loopback()
|| mask_ip.is_unspecified()
|| interface_ips.contains(&mask_ip))
{
return true;
}
}
false
}
#[cfg(test)]
fn is_mask_target_local_listener(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips();
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
async fn is_mask_target_local_listener_async(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips_async().await;
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
let minutes = config.general.beobachten_minutes;
let clamped = minutes.clamp(1, 24 * 60);
Duration::from_secs(clamped.saturating_mul(60))
}
fn build_mask_proxy_header(
version: u8,
peer: SocketAddr,
@ -290,13 +595,14 @@ pub async fn handle_bad_client<R, W>(
{
let client_type = detect_client_type(initial_data);
if config.general.beobachten {
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
let ttl = masking_beobachten_ttl(config);
beobachten.record(client_type, peer.ip(), ttl);
}
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
return;
}
@ -341,6 +647,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -353,12 +660,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -372,6 +679,28 @@ pub async fn handle_bad_client<R, W>(
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
.await
{
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback"
);
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
@ -381,10 +710,9 @@ pub async fn handle_bad_client<R, W>(
);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolve_socket_addr(mask_host, mask_port)
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
@ -413,6 +741,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -425,12 +754,12 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -449,6 +778,7 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool,
mask_relay_max_bytes: usize,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
@ -464,8 +794,18 @@ async fn relay_to_mask<R, W, MR, MW>(
}
let (upstream_copy, downstream_copy) = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
async {
copy_with_idle_timeout(
&mut reader,
&mut mask_write,
mask_relay_max_bytes,
!shape_hardening_enabled,
)
.await
},
async {
copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
}
);
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
@ -491,13 +831,36 @@ async fn relay_to_mask<R, W, MR, MW>(
let _ = writer.shutdown().await;
}
/// Just consume all data from client without responding
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = reader.read(&mut buf).await {
/// Just consume all data from client without responding.
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize) {
if byte_cap == 0 {
return;
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut total = 0usize;
loop {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
break;
}
total = total.saturating_add(n);
if total >= byte_cap {
break;
}
}
}
@ -521,6 +884,10 @@ mod masking_shape_above_cap_blur_security_tests;
#[path = "tests/masking_timing_normalization_security_tests.rs"]
mod masking_timing_normalization_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_budget_coupling_security_tests.rs"]
mod masking_timing_budget_coupling_security_tests;
#[cfg(test)]
#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
mod masking_ab_envelope_blur_integration_security_tests;
@ -548,3 +915,75 @@ mod masking_aggressive_mode_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests;
#[cfg(test)]
#[path = "tests/masking_self_target_loop_security_tests.rs"]
mod masking_self_target_loop_security_tests;
#[cfg(test)]
#[path = "tests/masking_classification_completeness_security_tests.rs"]
mod masking_classification_completeness_security_tests;
#[cfg(test)]
#[path = "tests/masking_relay_guardrails_security_tests.rs"]
mod masking_relay_guardrails_security_tests;
#[cfg(test)]
#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"]
mod masking_connect_failure_close_matrix_security_tests;
#[cfg(test)]
#[path = "tests/masking_additional_hardening_security_tests.rs"]
mod masking_additional_hardening_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_idle_timeout_security_tests.rs"]
mod masking_consume_idle_timeout_security_tests;
#[cfg(test)]
#[path = "tests/masking_http2_probe_classification_security_tests.rs"]
mod masking_http2_probe_classification_security_tests;
#[cfg(test)]
#[path = "tests/masking_http_probe_boundary_security_tests.rs"]
mod masking_http_probe_boundary_security_tests;
#[cfg(test)]
#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"]
mod masking_rng_hoist_perf_regression_tests;
#[cfg(test)]
#[path = "tests/masking_http2_preface_integration_security_tests.rs"]
mod masking_http2_preface_integration_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_stress_adversarial_tests.rs"]
mod masking_consume_stress_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_security_tests.rs"]
mod masking_interface_cache_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
mod masking_interface_cache_defense_in_depth_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"]
mod masking_interface_cache_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
mod masking_production_cap_regression_security_tests;
#[cfg(test)]
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
mod masking_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
mod masking_padding_timeout_adversarial_tests;
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
mod masking_offline_target_redteam_expected_fail_tests;

View File

@ -1,5 +1,7 @@
use std::collections::hash_map::RandomState;
use std::collections::{BTreeSet, HashMap};
#[cfg(test)]
use std::future::Future;
use std::hash::{BuildHasher, Hash};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
@ -39,10 +41,14 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
@ -94,10 +100,23 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
}
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> {
let registry = relay_idle_candidate_registry();
match registry.lock() {
Ok(guard) => guard,
Err(poisoned) => {
let mut guard = poisoned.into_inner();
// Fail closed after panic while holding registry lock: drop all
// candidates and pressure cursors to avoid stale cross-session state.
*guard = RelayIdleCandidateRegistry::default();
registry.clear_poison();
guard
}
}
}
fn mark_relay_idle_candidate(conn_id: u64) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
if guard.by_conn_id.contains_key(&conn_id) {
return false;
@ -116,9 +135,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
}
fn clear_relay_idle_candidate(conn_id: u64) {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@ -127,23 +144,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {
#[cfg(test)]
fn oldest_relay_idle_candidate() -> Option<u64> {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return None;
};
let guard = relay_idle_candidate_registry_lock();
guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
}
fn note_relay_pressure_event() {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
}
fn relay_pressure_event_seq() -> u64 {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return 0;
};
let guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq
}
@ -152,9 +163,7 @@ fn maybe_evict_idle_candidate_on_pressure(
seen_pressure_seq: &mut u64,
stats: &Stats,
) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
let latest_pressure_seq = guard.pressure_event_seq;
if latest_pressure_seq == *seen_pressure_seq {
@ -199,13 +208,9 @@ fn maybe_evict_idle_candidate_on_pressure(
#[cfg(test)]
fn clear_relay_idle_pressure_state_for_testing() {
if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
&& let Ok(mut guard) = registry.lock()
{
guard.by_conn_id.clear();
guard.ordered.clear();
guard.pressure_event_seq = 0;
guard.pressure_consumed_seq = 0;
if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
let mut guard = relay_idle_candidate_registry_lock();
*guard = RelayIdleCandidateRegistry::default();
}
RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
}
@ -259,6 +264,7 @@ impl RelayClientIdlePolicy {
struct RelayClientIdleState {
last_client_frame_at: Instant,
soft_idle_marked: bool,
tiny_frame_debt: u32,
}
impl RelayClientIdleState {
@ -266,6 +272,7 @@ impl RelayClientIdleState {
Self {
last_client_frame_at: now,
soft_idle_marked: false,
tiny_frame_debt: 0,
}
}
@ -535,6 +542,7 @@ fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>)
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
#[cfg_attr(not(test), allow(dead_code))]
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
@ -551,15 +559,6 @@ fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
@ -567,11 +566,8 @@ fn quota_would_be_exceeded_for_user_soft(
bytes: u64,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let cap = quota_soft_cap(quota, overshoot);
let used = stats.get_user_total_octets(user);
used >= cap || bytes > cap.saturating_sub(used)
})
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
}
fn classify_me_d2c_flush_reason(
@ -617,6 +613,16 @@ fn observe_me_d2c_flush_event(
}
}
fn rollback_me2c_quota_reservation(
stats: &Stats,
user: &str,
bytes_me2c: &AtomicU64,
reserved_bytes: u64,
) {
stats.sub_user_octets_to(user, reserved_bytes);
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
@ -630,6 +636,19 @@ fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[cfg(test)]
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
relay_idle_pressure_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
@ -665,6 +684,11 @@ fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
}
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -690,6 +714,16 @@ async fn enqueue_c2me_command(
}
}
#[cfg(test)]
async fn run_relay_test_step_timeout<F, T>(context: &'static str, fut: F) -> T
where
F: Future<Output = T>,
{
timeout(RELAY_TEST_STEP_TIMEOUT, fut)
.await
.unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
}
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
crypto_writer: CryptoWriter<W>,
@ -710,6 +744,8 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let cross_mode_quota_lock =
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -836,6 +872,7 @@ where
let stats_clone = stats.clone();
let rng_clone = rng.clone();
let user_clone = user.clone();
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@ -857,7 +894,7 @@ where
let first_is_downstream_activity =
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_cross_mode_lock(
first,
&mut writer,
proto_tag,
@ -867,6 +904,7 @@ where
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -915,7 +953,7 @@ where
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_cross_mode_lock(
next,
&mut writer,
proto_tag,
@ -925,6 +963,7 @@ where
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -976,7 +1015,7 @@ where
Ok(Some(next)) => {
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_cross_mode_lock(
next,
&mut writer,
proto_tag,
@ -986,6 +1025,7 @@ where
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -1039,7 +1079,7 @@ where
let extra_is_downstream_activity =
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_cross_mode_lock(
extra,
&mut writer,
proto_tag,
@ -1049,6 +1089,7 @@ where
&user_clone,
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
cross_mode_quota_lock_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@ -1221,6 +1262,14 @@ where
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
main_result = Err(ProxyError::Proxy(
"cross-mode quota lock missing for quota-limited session"
.to_string(),
));
break;
};
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
main_result = Err(ProxyError::DataQuotaExceeded {
@ -1320,6 +1369,8 @@ async fn read_client_payload_with_idle_policy<R>(
where
R: AsyncRead + Unpin + Send + 'static,
{
const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
async fn read_exact_with_policy<R>(
client_reader: &mut CryptoReader<R>,
buf: &mut [u8],
@ -1458,6 +1509,7 @@ where
Ok(())
}
let mut consecutive_zero_len_frames = 0u32;
loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => {
@ -1538,6 +1590,27 @@ where
};
if len == 0 {
idle_state.tiny_frame_debt = idle_state
.tiny_frame_debt
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(format!(
"Tiny frame overhead limit exceeded: debt={}, conn_id={}",
idle_state.tiny_frame_debt, forensics.conn_id
)));
}
if !idle_policy.enabled {
consecutive_zero_len_frames =
consecutive_zero_len_frames.saturating_add(1);
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(
"Excessive zero-length abridged frames".to_string(),
));
}
}
continue;
}
if len < 4 && proto_tag != ProtoTag::Abridged {
@ -1606,6 +1679,7 @@ where
}
*frame_counter += 1;
idle_state.on_client_frame(Instant::now());
idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
clear_relay_idle_candidate(forensics.conn_id);
return Ok(Some((payload, quickack)));
}
@ -1681,6 +1755,7 @@ enum MeWriterResponseOutcome {
Close,
}
#[cfg(test)]
async fn process_me_writer_response<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
@ -1696,6 +1771,44 @@ async fn process_me_writer_response<W>(
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
process_me_writer_response_with_cross_mode_lock(
response,
client_writer,
proto_tag,
rng,
frame_buf,
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
None,
bytes_me2c,
conn_id,
ack_flush_immediate,
batched,
)
.await
}
async fn process_me_writer_response_with_cross_mode_lock<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
@ -1707,39 +1820,76 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if quota_would_be_exceeded_for_user_soft(
stats,
user,
quota_limit,
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
if let Some(limit) = quota_limit {
let owned_cross_mode_lock;
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
lock
} else {
owned_cross_mode_lock =
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
&owned_cross_mode_lock
};
let cross_mode_quota_guard = cross_mode_lock.lock().await;
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if quota_would_be_exceeded_for_user_soft(
stats,
user,
Some(limit),
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
stats.increment_me_d2c_write_mode(write_mode);
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
// If reservation loses a race or write fails, we roll back immediately.
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
if stats.get_user_total_octets(user) > soft_limit {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
if quota_exceeded_for_user_soft(
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
// but socket I/O proceeds without holding same-user cross-mode admission lock.
drop(cross_mode_quota_guard);
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
return Err(err);
}
};
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
// Do not fail immediately on exact boundary after a successful write.
// Returning an error here can bypass batch flush in the caller and risk
// dropping buffered ciphertext from CryptoWriter. The next frame is
// rejected by the pre-check at function entry.
} else {
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
stats.add_user_octets_to(user, data_len);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
}
Ok(MeWriterResponseOutcome::Continue {
@ -1978,3 +2128,55 @@ mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[cfg(test)]
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
mod hol_quota_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
mod quota_reservation_adversarial_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"]
mod middle_relay_zero_length_frame_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"]
mod middle_relay_tiny_frame_debt_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"]
mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
mod middle_relay_cross_mode_quota_reservation_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
mod middle_relay_cross_mode_lock_release_regression_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
mod middle_relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
mod middle_relay_quota_reservation_extreme_security_tests;

View File

@ -64,6 +64,7 @@ pub mod direct_relay;
pub mod handshake;
pub mod masking;
pub mod middle_relay;
pub mod quota_lock_registry;
pub mod relay;
pub mod route_mode;
pub mod session_eviction;

View File

@ -0,0 +1,88 @@
use dashmap::DashMap;
use std::sync::{Arc, OnceLock};
use tokio::sync::Mutex;
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(test)]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(Mutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
#[cfg(test)]
{
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
let mut entry = lookups.entry(user.to_string()).or_insert(0);
*entry += 1;
}
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
return cross_mode_quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
#[cfg(test)]
pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
lookups.clear();
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
lookups.get(user).map(|entry| *entry).unwrap_or(0)
}
#[cfg(test)]
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
mod quota_lock_registry_cross_mode_adversarial_tests;

View File

@ -62,7 +62,8 @@ use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::{Instant, Sleep};
use tracing::{debug, trace, warn};
// ============= Constants =============
@ -209,12 +210,16 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_lock: Option<Arc<Mutex<()>>>,
cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_active: Arc<AtomicBool>,
quota_write_retry_active: Arc<AtomicBool>,
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
quota_read_retry_attempt: u8,
quota_write_retry_attempt: u8,
epoch: Instant,
}
@ -230,30 +235,29 @@ impl<S> StatsIo<S> {
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
let cross_mode_quota_lock = quota_limit
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
Self {
inner,
counters,
stats,
user,
quota_lock,
cross_mode_quota_lock,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
quota_read_retry_sleep: None,
quota_write_retry_sleep: None,
quota_read_retry_attempt: 0,
quota_write_retry_attempt: 0,
epoch,
}
}
}
impl<S> Drop for StatsIo<S> {
fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active
.store(false, Ordering::Relaxed);
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
@ -281,20 +285,69 @@ fn is_quota_io_error(err: &io::Error) -> bool {
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
tokio::task::spawn(async move {
loop {
if !retry_active.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
if !retry_active.load(Ordering::Relaxed) {
break;
}
waker.wake_by_ref();
}
});
#[cfg(test)]
static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
#[cfg(test)]
static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
#[cfg(test)]
pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
}
#[cfg(test)]
pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
}
#[inline]
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
let shift = u32::from(retry_attempt.min(5));
let multiplier = 1_u32 << shift;
QUOTA_CONTENTION_RETRY_INTERVAL
.saturating_mul(multiplier)
.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
}
#[inline]
fn reset_quota_retry_scheduler(
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
wake_scheduled: &mut bool,
retry_attempt: &mut u8,
) {
*wake_scheduled = false;
*sleep_slot = None;
*retry_attempt = 0;
}
fn poll_quota_retry_sleep(
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
wake_scheduled: &mut bool,
retry_attempt: &mut u8,
cx: &mut Context<'_>,
) {
if !*wake_scheduled {
*wake_scheduled = true;
#[cfg(test)]
QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
*sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
*retry_attempt,
))));
}
if let Some(sleep) = sleep_slot.as_mut()
&& sleep.as_mut().poll(cx).is_ready()
{
*sleep_slot = None;
*wake_scheduled = false;
*retry_attempt = retry_attempt.saturating_add(1);
cx.waker().wake_by_ref();
}
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
@ -333,16 +386,47 @@ fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
Arc::clone(&stripes[hash % stripes.len()])
}
pub(crate) fn quota_user_lock_evict() {
if let Some(locks) = QUOTA_USER_LOCKS.get() {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
}
pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
let interval = interval.max(Duration::from_millis(1));
#[cfg(test)]
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
quota_user_lock_evict();
}
})
}
#[cfg(test)]
pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
interval: Duration,
) -> tokio::task::JoinHandle<()> {
spawn_quota_user_lock_evictor(interval)
}
#[cfg(test)]
pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
}
#[cfg(test)]
pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
}
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
@ -357,6 +441,11 @@ fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
}
}
#[cfg(test)]
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
@ -368,26 +457,16 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
this.quota_read_retry_active.store(false, Ordering::Relaxed);
Some(guard)
}
Ok(guard) => Some(guard),
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
this.quota_read_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_read_retry_active),
cx.waker().clone(),
);
}
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
}
@ -395,6 +474,29 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
reset_quota_retry_scheduler(
&mut this.quota_read_retry_sleep,
&mut this.quota_read_wake_scheduled,
&mut this.quota_read_retry_attempt,
);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
@ -460,27 +562,16 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
this.quota_write_retry_active
.store(false, Ordering::Relaxed);
Some(guard)
}
Ok(guard) => Some(guard),
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
this.quota_write_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_write_retry_active),
cx.waker().clone(),
);
}
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
@ -488,6 +579,29 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
None
};
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
poll_quota_retry_sleep(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
cx,
);
return Poll::Pending;
}
}
} else {
None
};
reset_quota_retry_scheduler(
&mut this.quota_write_retry_sleep,
&mut this.quota_write_wake_scheduled,
&mut this.quota_write_retry_attempt,
);
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
@ -780,6 +894,10 @@ mod relay_quota_model_adversarial_tests;
#[path = "tests/relay_quota_overflow_regression_tests.rs"]
mod relay_quota_overflow_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
mod relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
mod relay_watchdog_delta_security_tests;
@ -791,3 +909,63 @@ mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
mod relay_quota_lock_identity_security_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
mod relay_cross_mode_quota_lock_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
mod relay_quota_retry_scheduler_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
mod relay_cross_mode_quota_fairness_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
mod relay_cross_mode_pipeline_hol_integration_security_tests;
#[cfg(test)]
#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
mod relay_quota_retry_backoff_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
mod relay_quota_retry_backoff_benchmark_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
mod relay_dual_lock_backoff_regression_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
mod relay_dual_lock_contention_matrix_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
mod relay_dual_lock_race_harness_security_tests;
#[cfg(test)]
#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
mod relay_dual_lock_alternating_contention_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
mod relay_quota_retry_allocation_latency_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
mod relay_quota_lock_eviction_lifecycle_tdd_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
mod relay_quota_lock_eviction_stress_security_tests;

View File

@ -0,0 +1,100 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
#[tokio::test]
async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"CONNE").await.unwrap();
client_side
.write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"),
"mask backend must receive the full fragmented CONNECT probe"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.251-1"));
}

View File

@ -0,0 +1,129 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
let first = split_at.min(preface.len());
client_side.write_all(&preface[..first]).await.unwrap();
if first < preface.len() {
sleep(Duration::from_millis(delay_ms)).await;
client_side.write_all(&preface[first..]).await.unwrap();
}
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(&preface),
"mask backend must receive an intact HTTP/2 preface prefix"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains(&format!("{}-1", peer.ip())));
}
#[tokio::test]
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
let cases = [
(2usize, 0u64),
(3, 0),
(4, 0),
(2, 7),
(3, 7),
(8, 1),
];
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}
#[tokio::test]
async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
for split_at in 2usize..=12 {
let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}

View File

@ -0,0 +1,150 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_pipeline_prefetch_case(
prefetch_timeout_ms: u64,
delayed_tail_ms: u64,
peer: SocketAddr,
) -> (Vec<u8>, String) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"C").await.unwrap();
sleep(Duration::from_millis(delayed_tail_ms)).await;
client_side
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
(forwarded, snapshot)
}
#[tokio::test]
async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() {
let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must still receive full payload bytes in-order"
);
assert!(
snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"),
"unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}"
);
}
#[tokio::test]
async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() {
let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must receive full CONNECT payload"
);
assert!(
snapshot.contains("[HTTP]"),
"20ms budget should recover delayed fragmented prefix and classify as HTTP"
);
}
#[tokio::test]
async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() {
let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap();
let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap();
let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap();
let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await;
let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await;
let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await;
assert!(
snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"),
"unexpected 5ms snapshot: {snap5}"
);
assert!(
snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"),
"unexpected 20ms snapshot: {snap20}"
);
assert!(snap50.contains("[HTTP]"));
}

View File

@ -0,0 +1,82 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep};
#[test]
fn prefetch_timeout_budget_reads_from_config() {
let mut cfg = ProxyConfig::default();
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(5),
"default prefetch timeout budget must remain 5ms"
);
cfg.censorship.mask_classifier_prefetch_timeout_ms = 20;
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(20),
"runtime prefetch timeout budget must follow configured value"
);
}
#[tokio::test]
async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(20),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
initial_data.starts_with(b"CONNECT"),
"20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
);
}
#[tokio::test]
async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(5),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
!initial_data.starts_with(b"CONNECT"),
"5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
);
}

View File

@ -0,0 +1,261 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
struct PipelineHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
PipelineHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
#[test]
fn empty_initial_data_prefetch_gate_is_fail_closed() {
assert!(
!should_prefetch_mask_classifier_window(&[]),
"empty initial_data must not trigger classifier prefetch"
);
}
#[tokio::test]
async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.is_empty(),
"empty initial_data must remain empty after prefetch stage"
);
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(
remaining, payload,
"prefetch stage must not consume fallback payload when initial_data is empty"
);
}
#[tokio::test]
async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
let (mut reader, mut writer) = duplex(1024);
writer
.write_all(b"NECT example.org:443 HTTP/1.1\r\n")
.await
.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = b"CON".to_vec();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.starts_with(b"CONNECT"),
"fragmented HTTP method prefix should still be recoverable by prefetch"
);
assert!(
initial_data.len() <= 16,
"prefetch window must remain bounded"
);
}
#[tokio::test]
async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
let mut seed = 0xD15C_A11E_2026_0322u64;
for _ in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = ((seed & 0x3f) as usize).saturating_add(1);
let mut payload = vec![0u8; len];
for (idx, byte) in payload.iter_mut().enumerate() {
*byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
}
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(initial_data.is_empty());
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(remaining, payload);
}
}
#[tokio::test]
async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0xD3u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B);
let mut invalid_payload = vec![0u8; HANDSHAKE_LEN];
invalid_payload[0] = 0xFF;
let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload);
let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant");
let expected = trailing_record.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
let mut one = [0u8; 1];
let n = stream.read(&mut one).await.unwrap();
assert_eq!(
n, 0,
"fallback stream must not append synthetic bytes on empty initial_data path"
);
});
let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port());
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.245:56145".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side.write_all(&invalid_mtproto_record).await.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,70 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, advance, sleep};
async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(tail_delay_ms)).await;
let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await;
let _ = writer.shutdown().await;
});
let mut initial_data = b"C".to_vec();
let mut prefetch_task = tokio::spawn(async move {
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_ms),
)
.await;
initial_data
});
tokio::task::yield_now().await;
if tail_delay_ms > 0 {
advance(Duration::from_millis(tail_delay_ms)).await;
tokio::task::yield_now().await;
}
if prefetch_ms > tail_delay_ms {
advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
tokio::task::yield_now().await;
}
let result = prefetch_task.await.expect("prefetch task must not panic");
writer_task.await.expect("writer task must not panic");
result
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_5ms_misses_15ms_tail() {
let got = run_strict_prefetch_case(5, 15).await;
assert_eq!(got, b"C".to_vec());
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_20ms_recovers_15ms_tail() {
let got = run_strict_prefetch_case(20, 15).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_50ms_recovers_35ms_tail() {
let got = run_strict_prefetch_case(50, 35).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
let got = run_strict_prefetch_case(20, 20).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_one_ms_after_budget_misses_tail() {
let got = run_strict_prefetch_case(20, 21).await;
assert_eq!(got, b"C".to_vec());
}

View File

@ -0,0 +1,95 @@
use super::*;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep, timeout};
async fn extend_masking_initial_window_with_budget<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = 16usize.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; 16];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(delayed_tail_ms)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer.shutdown().await.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_budget(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_budget_ms),
)
.await;
writer_task
.await
.expect("writer task must not panic during matrix case");
initial_data.starts_with(b"CONNECT")
}
#[tokio::test]
async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
let cases = [
// (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
(2u64, [true, true, true]),
(15u64, [false, true, true]),
(35u64, [false, false, true]),
];
for (tail_delay_ms, expected) in cases {
let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
assert_eq!(
got_5, expected[0],
"5ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_20, expected[1],
"20ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_50, expected[2],
"50ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
}
}
#[tokio::test]
async fn control_current_runtime_prefetch_budget_is_5ms() {
assert_eq!(
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
Duration::from_millis(5),
"matrix assumptions require current runtime prefetch budget to stay at 5ms"
);
}

View File

@ -0,0 +1,161 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
async fn run_replay_candidate_session(
replay_checker: Arc<ReplayChecker>,
hello: &[u8],
peer: SocketAddr,
drive_mtproto_fail: bool,
) -> Duration {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.censorship.mask_timing_normalization_enabled = false;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "abababababababababababababababab".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let task = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
replay_checker,
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten,
false,
));
client_side.write_all(hello).await.unwrap();
if drive_mtproto_fail {
let mut server_hello_head = [0u8; 5];
client_side.read_exact(&mut server_hello_head).await.unwrap();
assert_eq!(server_hello_head[0], 0x16);
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN);
invalid_mtproto_record.push(0x17);
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
client_side.write_all(&invalid_mtproto_record).await.unwrap();
client_side
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
started.elapsed()
}
#[tokio::test]
async fn replay_reject_still_honors_masking_timing_budget() {
let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60)));
let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51);
let seed_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.201:58001".parse().unwrap(),
true,
)
.await;
assert!(
seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250),
"seed replay-candidate run must honor masking timing budget without unbounded delay"
);
let replay_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.202:58002".parse().unwrap(),
false,
)
.await;
assert!(
replay_elapsed >= Duration::from_millis(40)
&& replay_elapsed < Duration::from_millis(250),
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
);
}

View File

@ -8,7 +8,7 @@ use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use rand::rngs::StdRng;
use rand::RngCore;
use rand::Rng;
use rand::SeedableRng;
use std::net::Ipv4Addr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

View File

@ -0,0 +1,93 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn adversarial_large_state_offsets_escape_first_scan_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut saw_offset_outside_first_window = false;
for i in 0..8_192u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(131)) & 0xff) as u8,
));
let now = base + Duration::from_nanos(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
if start >= scan_limit {
saw_offset_outside_first_window = true;
break;
}
}
assert!(
saw_offset_outside_first_window,
"scan start offset must cover the full auth-probe state, not only the first scan window"
);
}
#[test]
fn stress_large_state_offsets_cover_many_scan_windows() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut covered_windows = HashSet::new();
for i in 0..16_384u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(17)) & 0xff) as u8,
));
let now = base + Duration::from_micros(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
covered_windows.insert(start / scan_limit);
}
assert!(
covered_windows.len() >= 16,
"eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}",
covered_windows.len(),
state_len / scan_limit
);
}
#[test]
fn light_fuzz_offset_always_stays_inside_state_len() {
let _guard = auth_probe_test_guard();
let mut seed = 0xC0FF_EE12_3456_789Au64;
let base = Instant::now();
for _ in 0..8_192usize {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x0fff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(start < state_len, "scan offset must stay inside state length");
}
}

View File

@ -0,0 +1,99 @@
use super::*;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn edge_zero_state_len_yields_zero_start_offset() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
let now = Instant::now();
assert_eq!(
auth_probe_scan_start_offset(ip, now, 0, 16),
0,
"empty map must not produce non-zero scan offset"
);
}
#[test]
fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let scan_limit = 16usize;
let state_len = 65_536usize;
let mut saw_offset_outside_window = false;
for i in 0..2048u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
203,
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
let now = base + Duration::from_micros(i as u64);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"start offset must stay within state length; start={start}, len={state_len}"
);
if start >= scan_limit {
saw_offset_outside_window = true;
break;
}
}
assert!(
saw_offset_outside_window,
"large-state eviction must sample beyond the first scan window"
);
}
#[test]
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
let now = Instant::now();
for state_len in 1..32usize {
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
assert!(
start < state_len,
"start offset must never exceed state length when scan limit is larger"
);
}
}
#[test]
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
let _guard = auth_probe_test_guard();
let mut seed = 0x5A41_5356_4C32_3236u64;
let base = Instant::now();
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0xffff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"scan offset must stay inside state length"
);
}
}

View File

@ -0,0 +1,116 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
let base = Instant::now();
let mut uniq = HashSet::new();
for i in 0..512u64 {
let now = base + Duration::from_nanos(i);
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
uniq.insert(offset);
}
assert!(
uniq.len() >= 256,
"offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
uniq.len()
);
}
#[test]
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
let _guard = auth_probe_test_guard();
let now = Instant::now();
let mut uniq = HashSet::new();
for i in 0..1024u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
(i >> 16) as u8,
(i >> 8) as u8,
i as u8,
(255 - (i as u8)),
));
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
}
assert!(
uniq.len() >= 512,
"scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
uniq.len()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let start = Instant::now();
let mut workers = Vec::new();
for worker in 0..8u8 {
workers.push(tokio::spawn(async move {
for i in 0..8192u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
worker,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
}
}));
}
for worker in workers {
worker.await.expect("saturation worker must not panic");
}
assert!(
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"state must remain hard-capped under parallel saturation churn"
);
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
}
#[test]
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
let _guard = auth_probe_test_guard();
let mut seed = 0xA55A_1357_2468_9BDFu64;
let base = Instant::now();
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x1fff);
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
offset < state_len,
"scan offset must always remain inside state length"
);
}
}

View File

@ -0,0 +1,42 @@
use super::*;
fn handshake_source() -> &'static str {
include_str!("../handshake.rs")
}
#[test]
fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
"candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
"candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_aes_ctr_initialization_uses_zeroizing_references() {
let src = handshake_source();
assert!(
src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
&& src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
"AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
);
}
#[test]
fn security_success_struct_copies_out_of_zeroizing_wrappers() {
let src = handshake_source();
assert!(
src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
"HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
);
}

View File

@ -1,7 +1,7 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use rand::{RngExt, SeedableRng};
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

View File

@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
];
let mut meaningful_improvement_seen = false;
let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize;
let mut informative_baseline_sum = 0.0f64;
let mut informative_hardened_sum = 0.0f64;
let mut informative_pair_count = 0usize;
let mut low_info_baseline_sum = 0.0f64;
let mut low_info_hardened_sum = 0.0f64;
let mut low_info_pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
);
informative_baseline_sum += baseline_acc;
informative_hardened_sum += hardened_acc;
informative_pair_count += 1;
} else {
// Low-information pairs (near-random baseline separability) are expected
// to exhibit quantized jitter at low sample counts; do not fold them into
// strict average-regression checks used for informative side-channel signal.
low_info_baseline_sum += baseline_acc;
low_info_hardened_sum += hardened_acc;
low_info_pair_count += 1;
}
println!(
@ -532,19 +545,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
meaningful_improvement_seen = true;
}
baseline_sum += baseline_acc;
hardened_sum += hardened_acc;
pair_count += 1;
}
let baseline_avg = baseline_sum / pair_count as f64;
let hardened_avg = hardened_sum / pair_count as f64;
assert!(
informative_pair_count > 0,
"expected at least one informative pair for timing-separability guard"
);
let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
assert!(
hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
informative_hardened_avg <= informative_baseline_avg + 0.10,
"normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
);
if low_info_pair_count > 0 {
let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
assert!(
low_info_hardened_avg <= low_info_baseline_avg + 0.40,
"normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
);
}
// Optional signal only: do not require improvement on every run because
// noisy CI schedulers can flatten pairwise differences at low sample counts.
let _ = meaningful_improvement_seen;

View File

@ -0,0 +1,122 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::time::{Duration, timeout};
struct EndlessReader {
produced: Arc<AtomicUsize>,
}
impl AsyncRead for EndlessReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let len = buf.remaining().max(1);
let fill = vec![0xAA; len];
buf.put_slice(&fill);
self.produced.fetch_add(len, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[test]
fn loop_guard_unspecified_bind_uses_interface_inventory() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
let interfaces = vec!["192.168.44.10".parse().unwrap()];
assert!(is_mask_target_local_listener_with_interfaces(
"mask.example",
443,
local,
Some(resolved),
&interfaces,
));
}
#[tokio::test]
async fn consume_client_data_stops_after_byte_cap_without_eof() {
let produced = Arc::new(AtomicUsize::new(0));
let reader = EndlessReader {
produced: Arc::clone(&produced),
};
let cap = 10_000usize;
consume_client_data(reader, cap).await;
let total = produced.load(Ordering::Relaxed);
assert!(
total >= cap,
"consume path must read at least up to cap before stopping"
);
assert!(
total <= cap + 8192,
"consume path must stop within one read chunk above cap"
);
}
#[test]
fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 0;
let ttl = masking_beobachten_ttl(&config);
assert_eq!(ttl, std::time::Duration::from_secs(60));
}
#[test]
fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 0;
config.censorship.mask_timing_normalization_ceiling_ms = 0;
let budget = mask_outcome_target_budget(&config);
assert_eq!(budget, MASK_TIMEOUT);
}
#[tokio::test]
async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 2;
let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"loop guard must fail closed before any recursive PROXY protocol amplification"
);
}

View File

@ -0,0 +1,16 @@
use super::*;
#[test]
fn detect_client_type_recognizes_extended_http_probe_verbs() {
assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
}
#[test]
fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
assert_eq!(detect_client_type(b"CO"), "HTTP");
assert_eq!(detect_client_type(b"CON"), "HTTP");
assert_eq!(detect_client_type(b"TR"), "HTTP");
assert_eq!(detect_client_type(b"PAT"), "HTTP");
}

View File

@ -0,0 +1,127 @@
use super::*;
use crate::network::dns_overrides::install_entries;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
async fn run_connect_failure_case(
host: &str,
port: u16,
timing_normalization_enabled: bool,
peer: SocketAddr,
) -> Duration {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(host.to_string());
config.censorship.mask_port = port;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
let (mut client_writer, client_reader) = duplex(1024);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(n, 0, "connect-failure path must close client-visible writer");
started.elapsed()
}
#[tokio::test]
async fn connect_failure_refusal_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"127.0.0.1",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized refusal path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized refusal path must honor baseline connect budget without stalling"
);
}
}
}
#[tokio::test]
async fn connect_failure_overridden_hostname_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
// Make hostname resolution deterministic in tests so timing ceilings are meaningful.
install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"mask.invalid",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized overridden-host path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized overridden-host path must honor baseline connect budget without stalling"
);
}
}
install_entries(&[]).unwrap();
}

View File

@ -0,0 +1,85 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0x42]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn stalling_client_terminates_at_idle_not_relay_timeout() {
let reader = OneByteThenStall { sent: false };
let started = Instant::now();
let result = tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(reader, MASK_BUFFER_SIZE * 4),
)
.await;
assert!(
result.is_ok(),
"consume_client_data should complete by per-read idle timeout, not hit relay timeout"
);
let elapsed = started.elapsed();
assert!(
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
"consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
);
assert!(
elapsed < MASK_RELAY_TIMEOUT,
"consume_client_data waited full relay timeout ({elapsed:?}); \
per-read idle timeout is missing"
);
}
#[tokio::test]
async fn fast_reader_drains_to_eof() {
let data = vec![0xAAu8; 32 * 1024];
let reader = std::io::Cursor::new(data);
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
.await
.expect("consume_client_data did not complete for fast EOF reader");
}
#[tokio::test]
async fn io_error_terminates_cleanly() {
struct ErrReader;
impl AsyncRead for ErrReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"simulated reset",
)))
}
}
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX))
.await
.expect("consume_client_data did not return on I/O error");
}

View File

@ -0,0 +1,64 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::task::JoinSet;
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0xAA]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn consume_stall_stress_finishes_within_idle_budget() {
let mut set = JoinSet::new();
let started = Instant::now();
for _ in 0..64 {
set.spawn(async {
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
)
.await
.expect("consume_client_data exceeded relay timeout under stall load");
});
}
while let Some(res) = set.join_next().await {
res.unwrap();
}
// Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
// for 100ms should complete well under a strict 600ms boundary.
assert!(
started.elapsed() < MASK_RELAY_TIMEOUT * 3,
"stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
);
}
#[tokio::test]
async fn consume_zero_cap_returns_immediately() {
let started = Instant::now();
consume_client_data(tokio::io::empty(), 0).await;
assert!(
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
"zero byte cap must return immediately"
);
}

View File

@ -0,0 +1,217 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
fn make_self_target_config(
timing_normalization_enabled: bool,
floor_ms: u64,
ceiling_ms: u64,
beobachten_enabled: bool,
) -> ProxyConfig {
let mut config = ProxyConfig::default();
config.general.beobachten = beobachten_enabled;
config.general.beobachten_minutes = 5;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = floor_ms;
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
config
}
async fn run_self_target_refusal(
config: ProxyConfig,
peer: SocketAddr,
initial: &'static [u8],
) -> Duration {
let beobachten = BeobachtenStore::new();
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten)
.await;
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
timeout(Duration::from_secs(3), task)
.await
.expect("self-target refusal must complete in bounded time")
.expect("self-target refusal task must not panic");
started.elapsed()
}
#[tokio::test]
async fn positive_self_target_refusal_honors_normalization_floor() {
let config = make_self_target_config(true, 120, 120, false);
let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260),
"normalized self-target refusal must stay within expected envelope"
);
}
#[tokio::test]
async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() {
let config = make_self_target_config(false, 240, 240, false);
let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(180),
"non-normalized path must not inherit normalization floor delays"
);
}
#[tokio::test]
async fn edge_ceiling_below_floor_uses_floor_fail_closed() {
let config = make_self_target_config(true, 140, 80, false);
let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280),
"ceiling<floor must clamp to floor to preserve deterministic normalization"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_blackhat_parallel_probes_remain_bounded_and_uniform() {
let workers = 24usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let cfg = make_self_target_config(true, 110, 140, false);
let peer: SocketAddr = format!("203.0.113.50:{}", 54100 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(cfg, peer, b"GET /x HTTP/1.1\r\n\r\n").await
}));
}
let mut min = Duration::from_secs(60);
let mut max = Duration::from_millis(0);
for task in tasks {
let elapsed = task.await.expect("probe task must not panic");
if elapsed < min {
min = elapsed;
}
if elapsed > max {
max = elapsed;
}
assert!(
elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320),
"parallel probe latency must stay bounded under normalization"
);
}
assert!(
max.saturating_sub(min) <= Duration::from_millis(130),
"normalization should limit path variance across adversarial parallel probes"
);
}
#[tokio::test]
async fn integration_beobachten_records_probe_classification_on_refusal() {
let config = make_self_target_config(false, 0, 0, true);
let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer");
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
beobachten.snapshot_text(Duration::from_secs(60))
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
let snapshot = timeout(Duration::from_secs(3), task)
.await
.expect("integration task must complete")
.expect("integration task must not panic");
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.71-1"));
}
#[tokio::test]
async fn light_fuzz_timing_configuration_matrix_is_bounded() {
let mut seed = 0xA17E_55AA_2026_0323u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let enabled = (seed & 1) == 0;
let floor = (seed >> 8) % 180;
let ceiling = (seed >> 24) % 180;
let config = make_self_target_config(enabled, floor, ceiling, false);
let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16))
.parse()
.expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(420),
"fuzz case must stay bounded and never hang"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() {
let workers = 64usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let config = make_self_target_config(false, 0, 0, false);
let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await
}));
}
timeout(Duration::from_secs(5), async {
for task in tasks {
let elapsed = task.await.expect("stress task must not panic");
assert!(
elapsed < Duration::from_millis(260),
"stress refusal must remain bounded without normalization"
);
}
})
.await
.expect("high-fanout refusal workload must complete without deadlock");
}

View File

@ -0,0 +1,55 @@
use super::*;
use tokio::net::TcpListener;
use tokio::time::Duration;
#[tokio::test]
async fn http2_preface_is_forwarded_and_recorded_as_http() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let preface = preface.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; preface.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, preface);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = tokio::io::duplex(512);
let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&preface,
peer,
local_addr,
&config,
&beobachten,
)
.await;
tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.130-1"));
}

View File

@ -0,0 +1,92 @@
use super::*;
#[test]
fn full_http2_preface_classified_as_http_probe() {
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
assert!(
is_http_probe(preface),
"HTTP/2 connection preface must be classified as HTTP probe"
);
}
#[test]
fn partial_http2_preface_3_bytes_classified() {
assert!(
is_http_probe(b"PRI"),
"3-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn partial_http2_preface_2_bytes_classified() {
assert!(
is_http_probe(b"PR"),
"2-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn existing_http1_methods_unaffected() {
for prefix in [
b"GET / HTTP/1.1\r\n".as_ref(),
b"POST /api HTTP/1.1\r\n".as_ref(),
b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(),
b"TRACE / HTTP/1.1\r\n".as_ref(),
b"PATCH / HTTP/1.1\r\n".as_ref(),
] {
assert!(is_http_probe(prefix));
}
}
#[test]
fn non_http_data_not_classified() {
for data in [
b"\x16\x03\x01\x00\xf1".as_ref(),
b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(),
b"\x00\x01\x02\x03".as_ref(),
b"".as_ref(),
b"P".as_ref(),
] {
assert!(!is_http_probe(data));
}
}
#[test]
fn light_fuzz_non_http_prefixes_not_misclassified() {
// Deterministic pseudo-fuzz to exercise classifier edges while avoiding
// known HTTP method and partial windows.
let mut x = 0x1234_5678u32;
for _ in 0..1024 {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
let len = 4 + ((x >> 8) as usize % 12);
let mut data = vec![0u8; len];
for byte in &mut data {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = (x & 0xFF) as u8;
}
if [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"DELETE".as_ref(),
b"OPTIONS".as_ref(),
b"CONNECT".as_ref(),
b"TRACE".as_ref(),
b"PATCH".as_ref(),
b"PRI ".as_ref(),
]
.iter()
.any(|m| data.starts_with(m))
{
continue;
}
assert!(
!is_http_probe(&data),
"non-http pseudo-fuzz input misclassified: {:?}",
&data[..data.len().min(8)]
);
}
}

View File

@ -0,0 +1,79 @@
use super::*;
#[test]
fn exact_four_byte_http_tokens_are_classified() {
for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] {
assert!(
is_http_probe(token),
"exact 4-byte token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn exact_four_byte_non_http_tokens_are_not_classified() {
for token in [
b"GEX ".as_ref(),
b"POXT".as_ref(),
b"HEA/".as_ref(),
b"PU\0 ".as_ref(),
b"PRI/".as_ref(),
] {
assert!(
!is_http_probe(token),
"non-HTTP 4-byte token must not be classified: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"PRI "), "HTTP");
}
#[test]
fn exact_long_http_tokens_are_classified() {
for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] {
assert!(
is_http_probe(token),
"exact long HTTP token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() {
assert_eq!(detect_client_type(b"CONNECT"), "HTTP");
assert_eq!(detect_client_type(b"TRACE"), "HTTP");
assert_eq!(detect_client_type(b"PATCH"), "HTTP");
}
#[test]
fn light_fuzz_four_byte_ascii_noise_not_misclassified() {
// Deterministic pseudo-fuzz over 4-byte printable ASCII inputs.
let mut x = 0xA17C_93E5u32;
for _ in 0..2048 {
let mut token = [0u8; 4];
for byte in &mut token {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
}
if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
.iter()
.any(|m| token.as_slice() == *m)
{
continue;
}
assert!(
!is_http_probe(&token),
"pseudo-fuzz noise misclassified as HTTP probe: {:?}",
token
);
}
}

View File

@ -0,0 +1,41 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
use tokio::sync::Barrier;
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let workers = 32usize;
let barrier = std::sync::Arc::new(Barrier::new(workers));
let mut tasks = Vec::with_capacity(workers);
for _ in 0..workers {
let barrier = std::sync::Arc::clone(&barrier);
tasks.push(tokio::spawn(async move {
barrier.wait().await;
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
}));
}
for task in tasks {
let _ = task.await.expect("parallel cache task must not panic");
}
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"parallel cold misses must coalesce into a single interface enumeration"
);
}

View File

@ -0,0 +1,51 @@
#![cfg(unix)]
use super::*;
#[test]
fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert_eq!(
next, previous,
"empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
);
}
#[test]
fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = vec![
"10.55.0.3"
.parse::<IpAddr>()
.expect("must parse refreshed interface ip"),
];
let next = choose_interface_snapshot(&previous, refreshed.clone());
assert_eq!(next, refreshed);
}
#[test]
fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
let previous = Vec::new();
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert!(
next.is_empty(),
"empty refresh with no previous snapshot should remain empty"
);
}

View File

@ -0,0 +1,46 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"interface enumeration must be cached across repeated bad-client checks"
);
}
#[tokio::test]
async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
assert!(!is_local, "different port must not be treated as local listener");
assert_eq!(
local_interface_enumerations_for_tests(),
0,
"port mismatch should bypass interface enumeration entirely"
);
}

View File

@ -0,0 +1,178 @@
use super::*;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"]
async fn redteam_offline_target_should_drop_idle_client_early() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(150)).await;
let write_res = client_write.write_all(b"probe-should-be-closed").await;
assert!(
write_res.is_err(),
"offline target path still keeps client writable before consume timeout"
);
handler.abort();
}
#[tokio::test]
#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"]
async fn redteam_offline_target_should_not_sleep_to_mask_refusal() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"\x16\x03\x01\x00\x05hello",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
let elapsed = started.elapsed();
assert!(
elapsed < Duration::from_millis(10),
"offline target path still applies coarse masking sleep and is fingerprintable"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"]
async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() {
let mut samples = Vec::new();
for i in 0..12u16 {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
samples.push(started.elapsed());
}
let min = samples.iter().copied().min().unwrap_or_default();
let max = samples.iter().copied().max().unwrap_or_default();
let spread = max.saturating_sub(min);
assert!(
spread <= Duration::from_millis(5),
"offline refusal timing spread too wide for strict red-team envelope: {:?}",
spread
);
}
#[tokio::test]
#[ignore = "manual red-team: host resolver failure should complete without panic"]
async fn redteam_dns_resolution_failure_must_not_panic() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
cfg.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
assert!(
result.is_ok(),
"dns failure path stalled or panicked instead of terminating"
);
}

View File

@ -0,0 +1,51 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::AsyncWrite;
struct NeverWritable;
impl AsyncWrite for NeverWritable {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
let mut writer = NeverWritable;
let started = Instant::now();
maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await;
assert!(
started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
"shape padding blocked past timeout budget"
);
}
#[tokio::test]
async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
let mut output = Vec::new();
{
let mut writer = tokio::io::BufWriter::new(&mut output);
maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
use tokio::io::AsyncWriteExt;
writer.flush().await.unwrap();
}
assert!(output.is_empty());
}

View File

@ -0,0 +1,289 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::{Duration, Instant, timeout};
const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
struct FinitePatternReader {
remaining: usize,
chunk: usize,
read_calls: Arc<AtomicUsize>,
}
impl FinitePatternReader {
fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
chunk,
read_calls,
}
}
}
impl AsyncRead for FinitePatternReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.read_calls.fetch_add(1, Ordering::Relaxed);
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(self.chunk).min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0x5Au8; take];
buf.put_slice(&fill);
self.remaining -= take;
Poll::Ready(Ok(()))
}
}
#[derive(Default)]
struct CountingWriter {
written: usize,
}
impl AsyncWrite for CountingWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.written = self.written.saturating_add(buf.len());
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct NeverReadyReader;
impl AsyncRead for NeverReadyReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
struct BudgetProbeReader {
remaining: usize,
total_read: Arc<AtomicUsize>,
}
impl BudgetProbeReader {
fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
total_read,
}
}
}
impl AsyncRead for BudgetProbeReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0xA5u8; take];
buf.put_slice(&fill);
self.remaining -= take;
self.total_read.fetch_add(take, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"copy path must stop at explicit production cap"
);
assert_eq!(writer.written, PROD_CAP_BYTES);
assert!(
!outcome.ended_by_eof,
"byte-cap stop must not be misclassified as EOF"
);
}
#[tokio::test]
async fn negative_consume_with_zero_cap_performs_no_reads() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
consume_client_data_with_timeout_and_cap(reader, 0).await;
assert_eq!(
read_calls.load(Ordering::Relaxed),
0,
"zero cap must return before reading attacker-controlled bytes"
);
}
#[tokio::test]
async fn edge_copy_below_cap_reports_eof_without_overread() {
let read_calls = Arc::new(AtomicUsize::new(0));
let payload = 73 * 1024;
let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(outcome.total, payload);
assert_eq!(writer.written, payload);
assert!(
outcome.ended_by_eof,
"finite upstream below cap must terminate via EOF path"
);
}
#[tokio::test]
async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
let started = Instant::now();
consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;
assert!(
started.elapsed() < Duration::from_millis(350),
"never-ready reader must be bounded by idle/relay timeout protections"
);
}
#[tokio::test]
async fn integration_consume_path_honors_production_cap_for_large_payload() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);
let bounded = timeout(
Duration::from_millis(350),
consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
)
.await;
assert!(
bounded.is_ok(),
"consume path with production cap must finish within bounded time"
);
}
#[tokio::test]
async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() {
let byte_cap = 5usize;
let total_read = Arc::new(AtomicUsize::new(0));
let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read));
consume_client_data_with_timeout_and_cap(reader, byte_cap).await;
assert!(
total_read.load(Ordering::Relaxed) <= byte_cap,
"consume path must not read more than configured byte cap"
);
}
#[tokio::test]
async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() {
let mut seed = 0x1234_5678_9ABC_DEF0u64;
for _case in 0..96u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let cap = ((seed & 0x3ffff) as usize).saturating_add(1);
let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1);
let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1);
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(payload, chunk, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await;
let expected = payload.min(cap);
assert_eq!(
outcome.total, expected,
"copy total must match min(payload, cap) under fuzzed inputs"
);
assert_eq!(writer.written, expected);
if payload <= cap {
assert!(outcome.ended_by_eof);
} else {
assert!(!outcome.ended_by_eof);
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() {
let workers = 8usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(
PROD_CAP_BYTES + (idx + 1) * 4096,
4096 + (idx * 257),
read_calls,
);
let mut writer = CountingWriter::default();
copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await
}));
}
timeout(Duration::from_secs(3), async {
for task in tasks {
let outcome = task.await.expect("stress task must not panic");
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"stress copy task must stay within production cap"
);
assert!(
!outcome.ended_by_eof,
"stress task should end due to cap, not EOF"
);
}
})
.await
.expect("stress suite must complete in bounded time");
}

View File

@ -0,0 +1,105 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_to_mask_enforces_masking_session_byte_cap() {
let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01];
let extra = vec![0xAB; 96 * 1024];
let (client_reader, mut client_writer) = duplex(128 * 1024);
let (mask_read, _mask_read_peer) = duplex(1024);
let (mut mask_observer, mask_write) = duplex(256 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_secs(2),
mask_observer.read_to_end(&mut observed),
)
.await
.unwrap()
.unwrap();
// In this deterministic test, relay must stop exactly at the configured cap.
assert_eq!(
observed.len(),
initial.len() + (32 * 1024),
"masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
observed.len(),
initial.len(),
32 * 1024
);
}
#[tokio::test]
async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
let (client_reader, mut client_writer) = duplex(8 * 1024);
let (mask_read, _mask_read_peer) = duplex(8 * 1024);
let (mut mask_observer, mask_write) = duplex(8 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.shutdown().await.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_millis(80),
mask_observer.read_to_end(&mut observed),
)
.await
.expect("mask backend write side should be half-closed promptly")
.unwrap();
assert_eq!(&observed[..initial.len()], initial.as_slice());
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,100 @@
use super::*;
use tokio::io::AsyncReadExt;
use tokio::time::{Duration, timeout};
async fn collect_padding(
total_sent: usize,
enabled: bool,
floor: usize,
cap: usize,
above_cap_blur: bool,
blur_max: usize,
aggressive: bool,
) -> Vec<u8> {
let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
maybe_write_shape_padding(
&mut tx,
total_sent,
enabled,
floor,
cap,
above_cap_blur,
blur_max,
aggressive,
)
.await;
drop(tx);
let mut output = Vec::new();
timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
.await
.expect("reading padded output timed out")
.expect("failed reading padded output");
output
}
#[tokio::test]
async fn padding_output_is_not_all_zero() {
let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
assert!(
output.len() >= 255,
"expected at least 255 padding bytes, got {}",
output.len()
);
let nonzero = output.iter().filter(|&&b| b != 0).count();
// In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
// A weak nonzero check can miss severe entropy collapse.
assert!(
nonzero >= 240,
"RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}",
nonzero,
output.len(),
);
}
#[tokio::test]
async fn padding_reaches_first_bucket_boundary() {
let output = collect_padding(1, true, 64, 4096, false, 0, false).await;
assert_eq!(output.len(), 63);
}
#[tokio::test]
async fn disabled_padding_produces_no_output() {
let output = collect_padding(0, false, 256, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn at_cap_without_blur_produces_no_output() {
let output = collect_padding(4096, true, 64, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() {
let output = collect_padding(4096, true, 64, 4096, true, 128, true).await;
assert!(!output.is_empty());
assert!(output.len() <= 128, "blur exceeded max: {}", output.len());
}
#[tokio::test]
async fn stress_padding_runs_are_not_constant_pattern() {
// Stress and sanity-check: repeated runs should not collapse to identical
// first 16 bytes across all samples.
let mut first_chunks = Vec::new();
for _ in 0..64 {
let out = collect_padding(1, true, 64, 4096, false, 0, false).await;
first_chunks.push(out[..16].to_vec());
}
let first = &first_chunks[0];
let all_same = first_chunks.iter().all(|chunk| chunk == first);
assert!(
!all_same,
"all stress samples had identical prefix, rng output appears degenerate"
);
}

View File

@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});
@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
false,
0,
false,
5 * 1024 * 1024,
),
)
.await;

View File

@ -0,0 +1,360 @@
use super::*;
use std::net::TcpListener as StdTcpListener;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant, timeout};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
async fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener_async(
"198.51.100.40",
443,
local,
None,
)
.await);
}
#[tokio::test]
async fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener_async(
"[2001:db8::44]",
8443,
local,
None,
)
.await);
}
#[tokio::test]
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async(
"203.0.113.44",
8443,
local,
None,
)
.await);
}
#[tokio::test]
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener_async(
"::ffff:127.0.0.1",
443,
local,
None,
)
.await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener_async(
"127.0.0.1",
443,
local,
None,
)
.await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async(
"mask.example",
443,
local,
Some(remote),
)
.await);
}
#[tokio::test]
async fn self_target_fallback_refuses_recursive_loopback_connect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(local_addr.ip().to_string());
config.censorship.mask_port = local_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET /",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"self-target masking must fail closed without connecting to local listener"
);
}
#[tokio::test]
async fn same_ip_different_port_still_forwards_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /".to_vec();
let accept_task = tokio::spawn({
let expected = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
}
#[test]
fn detect_client_type_http_boundary_get_and_post() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"GET /"), "HTTP");
assert_eq!(detect_client_type(b"POST"), "HTTP");
assert_eq!(detect_client_type(b"POST "), "HTTP");
assert_eq!(detect_client_type(b"POSTX"), "HTTP");
}
#[test]
fn detect_client_type_tls_and_length_boundaries() {
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner");
assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner");
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
#[test]
fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() {
let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap();
let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let header = build_mask_proxy_header(1, peer, local).unwrap();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
#[test]
fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() {
let floor = usize::MAX / 2 + 1;
let cap = usize::MAX;
let total = floor + 1;
assert_eq!(next_mask_shape_bucket(total, floor, cap), total);
}
#[tokio::test]
async fn self_target_reject_path_keeps_timing_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client, server) = duplex(1024);
drop(client);
let started = Instant::now();
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250),
"self-target reject path must keep coarse timing budget without stalling"
);
}
#[tokio::test]
async fn relay_path_idle_timeout_eviction_remains_effective() {
let (client_read, mut client_write) = duplex(1024);
let (mask_read, mask_write) = duplex(1024);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
client_write.write_all(b"a").await.unwrap();
tokio::time::sleep(Duration::from_millis(180)).await;
let _ = client_write.write_all(b"b").await;
});
let started = Instant::now();
relay_to_mask(
client_read,
tokio::io::sink(),
mask_read,
mask_write,
b"init",
false,
0,
0,
false,
0,
false,
5 * 1024 * 1024,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180),
"idle-timeout eviction must occur before late trickle write"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client.shutdown().await.unwrap();
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220),
"offline-refusal path must honor normalization budget without unbounded drift"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = false;
let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(120)).await;
client
.write_all(b"still-open-before-timeout")
.await
.expect("connection should still be open before consume timeout expires");
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350),
"offline-refusal path must not retain idle client indefinitely"
);
}

View File

@ -43,6 +43,7 @@ async fn run_relay_case(
above_cap_blur,
above_cap_blur_max_bytes,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -0,0 +1,55 @@
#![cfg(unix)]
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer");
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let held_refresh_guard = refresh_lock.lock().await;
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(80)).await;
drop(held_refresh_guard);
client.shutdown().await.expect("client shutdown must succeed");
timeout(Duration::from_secs(2), task)
.await
.expect("task must finish in bounded time")
.expect("task must not panic");
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350),
"timing normalization floor must start after pre-outcome self-target checks"
);
}

View File

@ -9,6 +9,7 @@ use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
let _guard = super::quota_user_lock_test_scope();
let _pressure_guard = super::relay_idle_pressure_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();

View File

@ -645,6 +645,75 @@ fn quota_exceeded_boundary_is_inclusive() {
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
}
#[test]
fn quota_soft_helper_matches_capped_generic_helper_matrix() {
let stats = Stats::new();
let user = "quota-soft-parity";
for used in [0u64, 1, 7, 63, 127, 255] {
stats.sub_user_octets_to(user, stats.get_user_total_octets(user));
stats.add_user_octets_to(user, used);
for quota in [8u64, 64, 128, 256] {
for overshoot in [0u64, 1, 5, 32] {
for bytes in [0u64, 1, 2, 7, 31, 64] {
let soft = quota_would_be_exceeded_for_user_soft(
&stats,
user,
Some(quota),
bytes,
overshoot,
);
let capped = quota_would_be_exceeded_for_user(
&stats,
user,
Some(quota_soft_cap(quota, overshoot)),
bytes,
);
assert_eq!(
soft, capped,
"soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}"
);
}
}
}
}
}
#[test]
fn quota_soft_helper_none_limit_never_rejects() {
let stats = Stats::new();
let user = "quota-soft-none";
stats.add_user_octets_to(user, u64::MAX);
assert!(!quota_would_be_exceeded_for_user_soft(
&stats,
user,
None,
u64::MAX,
u64::MAX,
));
}
#[test]
fn quota_soft_cap_saturates_and_stays_fail_closed() {
let stats = Stats::new();
let user = "quota-soft-saturating";
let quota = u64::MAX - 2;
let overshoot = 100;
assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX);
stats.add_user_octets_to(user, u64::MAX - 1);
assert!(quota_would_be_exceeded_for_user_soft(
&stats,
user,
Some(quota),
2,
overshoot,
));
}
#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);

View File

@ -0,0 +1,295 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[derive(Default)]
struct BlockingWriteState {
write_entered: AtomicBool,
released: AtomicBool,
write_waker: Mutex<Option<Waker>>,
write_entered_notify: Notify,
}
struct BlockingWrite {
state: Arc<BlockingWriteState>,
}
impl BlockingWrite {
fn new(state: Arc<BlockingWriteState>) -> Self {
Self { state }
}
}
impl AsyncWrite for BlockingWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.state.write_entered.store(true, Ordering::Release);
self.state.write_entered_notify.notify_waiters();
if self.state.released.load(Ordering::Acquire) {
return Poll::Ready(Ok(buf.len()));
}
if let Ok(mut slot) = self.state.write_waker.lock() {
*slot = Some(cx.waker().clone());
}
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
for _ in 0..8 {
if state.write_entered.load(Ordering::Acquire) {
return;
}
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
}
panic!("blocking writer did not enter poll_write in bounded time");
}
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
state.released.store(true, Ordering::Release);
if let Ok(mut slot) = state.write_waker.lock()
&& let Some(waker) = slot.take()
{
waker.wake();
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() {
let stats = Arc::new(Stats::new());
let user = format!("middle-cross-release-regression-{}", std::process::id());
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
let bytes_me2c = Arc::new(AtomicU64::new(0));
let writer_state = Arc::new(BlockingWriteState::default());
let first = {
let stats = Arc::clone(&stats);
let user = user.clone();
let cross_mode_lock = Arc::clone(&cross_mode_lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
let writer_state = Arc::clone(&writer_state);
tokio::spawn(async move {
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
let mut frame_buf = Vec::new();
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(4),
0,
Some(&cross_mode_lock),
bytes_me2c.as_ref(),
41_000,
false,
false,
)
.await
})
};
wait_until_blocking_write_entered(&writer_state).await;
let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
.await
.expect("cross-mode lock must be released while first write is pending");
drop(guard);
let second = {
let stats = Arc::clone(&stats);
let user = user.clone();
let cross_mode_lock = Arc::clone(&cross_mode_lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
tokio::spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
timeout(
Duration::from_millis(150),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xEE]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(4),
0,
Some(&cross_mode_lock),
bytes_me2c.as_ref(),
41_001,
false,
false,
),
)
.await
})
};
let second_result = second
.await
.expect("second task must not panic")
.expect("second write must not block on cross-mode lock");
assert!(
matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })),
"second write must fail closed due to first write reservation"
);
release_blocking_write(&writer_state);
let first_result = timeout(Duration::from_millis(300), first)
.await
.expect("first task timed out")
.expect("first task must not panic");
assert!(first_result.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 4);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() {
let stats = Arc::new(Stats::new());
let user = format!("middle-cross-release-stress-{}", std::process::id());
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
let bytes_me2c = Arc::new(AtomicU64::new(0));
let writer_state = Arc::new(BlockingWriteState::default());
let first = {
let stats = Arc::clone(&stats);
let user = user.clone();
let cross_mode_lock = Arc::clone(&cross_mode_lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
let writer_state = Arc::clone(&writer_state);
tokio::spawn(async move {
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
let mut frame_buf = Vec::new();
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x01, 0x02]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(3),
0,
Some(&cross_mode_lock),
bytes_me2c.as_ref(),
41_100,
false,
false,
)
.await
})
};
wait_until_blocking_write_entered(&writer_state).await;
let mut set = JoinSet::new();
for idx in 0..48u64 {
let stats = Arc::clone(&stats);
let user = user.clone();
let cross_mode_lock = Arc::clone(&cross_mode_lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
set.spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
timeout(
Duration::from_millis(200),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x10]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(3),
0,
Some(&cross_mode_lock),
bytes_me2c.as_ref(),
41_200 + idx,
false,
false,
),
)
.await
});
}
let mut ok = 0usize;
let mut quota_exceeded = 0usize;
while let Some(done) = set.join_next().await {
let timed = done.expect("waiter task must not panic");
let result = timed.expect("waiter must not block behind pending first write");
match result {
Ok(_) => ok += 1,
Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1,
Err(other) => panic!("unexpected error in waiter: {other:?}"),
}
}
assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota");
assert_eq!(quota_exceeded, 47);
release_blocking_write(&writer_state);
let first_result = timeout(Duration::from_millis(300), first)
.await
.expect("first task timed out")
.expect("first task must not panic");
assert!(first_result.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 3);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
}

View File

@ -0,0 +1,116 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
fn lookup_counter_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("middle-cross-mode-lookup-{}", std::process::id());
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
for idx in 0..8u64 {
let outcome = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAB]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
Some(&cross_mode_lock),
&bytes_me2c,
20_000 + idx,
false,
false,
)
.await;
assert!(outcome.is_ok());
}
assert_eq!(
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
0,
"prefetched lock path must not re-query lock registry per frame"
);
assert_eq!(stats.get_user_total_octets(&user), 8);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8);
}
#[tokio::test]
async fn control_without_prefetched_lock_still_uses_registry_lookup_path() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("middle-cross-mode-lookup-control-{}", std::process::id());
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let outcome = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xCD]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
None,
&bytes_me2c,
20_100,
false,
false,
)
.await;
assert!(outcome.is_ok());
assert_eq!(
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
1,
"fallback path without prefetched lock should perform a registry lookup"
);
}

View File

@ -0,0 +1,376 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test]
async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() {
let stats = Stats::new();
let user = format!("middle-cross-matrix-positive-{}", std::process::id());
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3, 4]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(128),
0,
&bytes_me2c,
10_001,
false,
false,
)
.await;
assert!(result.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 4);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}
#[tokio::test]
async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() {
let stats = Stats::new();
let user = format!("middle-cross-matrix-negative-{}", std::process::id());
let held = cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock before ME->C call");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let blocked = timeout(
Duration::from_millis(25),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x41]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(256),
0,
&bytes_me2c,
10_002,
false,
false,
),
)
.await;
assert!(blocked.is_err());
drop(held_guard);
}
#[tokio::test]
async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() {
let stats = Stats::new();
let user = format!("middle-cross-matrix-edge-none-{}", std::process::id());
let held = cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock while quota is disabled");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let outcome = timeout(
Duration::from_millis(80),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x11, 0x22]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
None,
0,
&bytes_me2c,
10_003,
false,
false,
),
)
.await
.expect("quota-none path must not wait on cross-mode lock");
assert!(outcome.is_ok());
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() {
let stats = Arc::new(Stats::new());
let user = format!("middle-cross-matrix-adversarial-{}", std::process::id());
let limit = 64u64;
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut tasks = Vec::new();
for idx in 0..256u64 {
let stats = Arc::clone(&stats);
let bytes_me2c = Arc::clone(&bytes_me2c);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xEE]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(limit),
0,
bytes_me2c.as_ref(),
11_000 + idx,
false,
false,
)
.await
}));
}
let mut ok = 0usize;
for task in tasks {
match task.await.expect("task must not panic") {
Ok(_) => ok += 1,
Err(ProxyError::DataQuotaExceeded { .. }) => {}
Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"),
}
}
assert_eq!(ok, limit as usize);
assert_eq!(stats.get_user_total_octets(&user), limit);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() {
let user = format!("middle-cross-matrix-integration-{}", std::process::id());
let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user);
let middle_lock = cross_mode_quota_user_lock_for_tests(&user);
assert!(
Arc::ptr_eq(&relay_lock, &middle_lock),
"relay and middle-relay must share the same cross-mode lock identity"
);
let held_guard = relay_lock
.try_lock()
.expect("test must hold shared cross-mode lock");
let stats = Stats::new();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let middle_blocked = timeout(
Duration::from_millis(25),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x92]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
12_001,
false,
false,
),
)
.await;
assert!(middle_blocked.is_err());
drop(held_guard);
let middle_ready = timeout(
Duration::from_millis(250),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x94]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
12_002,
false,
false,
),
)
.await
.expect("middle path must complete after release");
assert!(middle_ready.is_ok());
}
#[tokio::test]
async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() {
let stats = Stats::new();
let user = format!("middle-cross-matrix-fuzz-{}", std::process::id());
let bytes_me2c = AtomicU64::new(0);
let mut seed = 0xC0DE_1234_55AA_9988u64;
for case in 0..96u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold = (seed & 0x03) == 0;
let mut held_lock = None;
let maybe_guard = if hold {
held_lock = Some(cross_mode_quota_user_lock_for_tests(&user));
Some(
held_lock
.as_ref()
.expect("held lock should be present")
.try_lock()
.expect("cross-mode lock should be acquirable in fuzz round"),
)
} else {
None
};
let payload_len = ((seed >> 8) as usize % 8) + 1;
let payload = vec![(seed & 0xff) as u8; payload_len];
let before = stats.get_user_total_octets(&user);
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let timed = timeout(
Duration::from_millis(20),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from(payload),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
13_000 + case as u64,
false,
false,
),
)
.await;
if hold {
assert!(timed.is_err(), "held-lock fuzz round must block within timeout");
assert_eq!(stats.get_user_total_octets(&user), before);
} else {
let done = timed.expect("unheld fuzz round must complete in time");
assert!(done.is_ok());
}
drop(maybe_guard);
drop(held_lock);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user));
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() {
let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id());
let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id());
let held = cross_mode_quota_user_lock_for_tests(&held_user);
let held_guard = held
.try_lock()
.expect("test must hold lock for blocked user");
let mut tasks = Vec::new();
for idx in 0..64u64 {
let user = free_user.clone();
tasks.push(tokio::spawn(async move {
let stats = Stats::new();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xA0]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1),
0,
&bytes_me2c,
14_000 + idx,
false,
false,
)
.await
}));
}
timeout(Duration::from_secs(2), async {
for task in tasks {
let done = task.await.expect("free-user task must not panic");
assert!(done.is_ok());
}
})
.await
.expect("free-user tasks should complete without waiting for held user's lock");
drop(held_guard);
}

View File

@ -0,0 +1,254 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::sync::Notify;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[derive(Default)]
struct BlockingWriteState {
write_entered: AtomicBool,
released: AtomicBool,
write_waker: Mutex<Option<Waker>>,
write_entered_notify: Notify,
}
struct BlockingWrite {
state: Arc<BlockingWriteState>,
}
impl BlockingWrite {
fn new(state: Arc<BlockingWriteState>) -> Self {
Self { state }
}
}
impl AsyncWrite for BlockingWrite {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.state.write_entered.store(true, Ordering::Release);
self.state.write_entered_notify.notify_waiters();
if self.state.released.load(Ordering::Acquire) {
return Poll::Ready(Ok(buf.len()));
}
if let Ok(mut slot) = self.state.write_waker.lock() {
*slot = Some(cx.waker().clone());
}
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
for _ in 0..8 {
if state.write_entered.load(Ordering::Acquire) {
return;
}
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
}
panic!("blocking writer did not enter poll_write in bounded time");
}
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
state.released.store(true, Ordering::Release);
if let Ok(mut slot) = state.write_waker.lock()
&& let Some(waker) = slot.take()
{
waker.wake();
}
}
#[tokio::test]
async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() {
let stats = Stats::new();
let user = format!("middle-me2c-cross-mode-held-{}", std::process::id());
let held = cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold shared cross-mode lock before ME->C write path");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let blocked = timeout(
Duration::from_millis(25),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x41]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
9901,
false,
false,
),
)
.await;
assert!(
blocked.is_err(),
"ME->C quota reservation path must be serialized by held shared cross-mode lock"
);
drop(held_guard);
let released = timeout(
Duration::from_millis(250),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x42]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
9902,
false,
false,
),
)
.await
.expect("ME->C write must complete after cross-mode lock release");
assert!(released.is_ok());
}
#[tokio::test]
async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() {
let stats = Stats::new();
let user = format!("middle-me2c-cross-mode-free-{}", std::process::id());
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let outcome = timeout(
Duration::from_millis(250),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x55, 0x66]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1024),
0,
&bytes_me2c,
9903,
false,
false,
),
)
.await
.expect("uncontended ME->C path should not stall");
assert!(outcome.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 2);
assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() {
let stats = Arc::new(Stats::new());
let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id());
let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user);
let bytes_me2c = Arc::new(AtomicU64::new(0));
let writer_state = Arc::new(BlockingWriteState::default());
let worker = {
let stats = Arc::clone(&stats);
let user = user.clone();
let cross_mode_lock = Arc::clone(&cross_mode_lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
let writer_state = Arc::clone(&writer_state);
tokio::spawn(async move {
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
let mut frame_buf = Vec::new();
let rng = SecureRandom::new();
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]),
},
&mut writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
stats.as_ref(),
&user,
Some(1024),
0,
Some(&cross_mode_lock),
bytes_me2c.as_ref(),
9910,
false,
false,
)
.await
})
};
wait_until_blocking_write_entered(&writer_state).await;
let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
.await
.expect("cross-mode lock must be free while ME->C write is pending");
drop(acquired_guard);
release_blocking_write(&writer_state);
let result = timeout(Duration::from_millis(300), worker)
.await
.expect("ME->C worker timed out after releasing blocking writer")
.expect("ME->C worker must not panic");
assert!(result.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 4);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
}

View File

@ -0,0 +1,232 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll, Waker};
use tokio::io::AsyncWrite;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[derive(Default)]
struct GateState {
open: AtomicBool,
parked_waker: std::sync::Mutex<Option<Waker>>,
}
impl GateState {
fn open(&self) {
self.open.store(true, Ordering::Relaxed);
if let Ok(mut guard) = self.parked_waker.lock()
&& let Some(w) = guard.take()
{
w.wake();
}
}
fn has_waiter(&self) -> bool {
self.parked_waker
.lock()
.map(|guard| guard.is_some())
.unwrap_or(false)
}
}
#[derive(Default)]
struct GateWriter {
gate: Arc<GateState>,
}
impl GateWriter {
fn new(gate: Arc<GateState>) -> Self {
Self { gate }
}
}
impl AsyncWrite for GateWriter {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if self.gate.open.load(Ordering::Relaxed) {
return Poll::Ready(Ok(buf.len()));
}
if let Ok(mut guard) = self.gate.parked_waker.lock() {
*guard = Some(cx.waker().clone());
}
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct FailingWriter;
impl AsyncWrite for FailingWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"injected writer failure",
)))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() {
let stats = Stats::new();
let bytes_me2c = AtomicU64::new(0);
let rng = SecureRandom::new();
let quota_limit = Some(1024);
let user = "hol-quota-user";
let gate = Arc::new(GateState::default());
let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate)));
let slow_task = tokio::spawn(async move {
let mut frame_buf = Vec::new();
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]),
},
&mut blocked_writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
user,
quota_limit,
0,
&bytes_me2c,
7001,
false,
false,
)
.await
});
timeout(Duration::from_millis(100), async {
loop {
if gate.has_waiter() {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("first writer must reach backpressure and park");
let stats_fast = Stats::new();
let bytes_fast = AtomicU64::new(0);
let rng_fast = SecureRandom::new();
let mut fast_writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf_fast = Vec::new();
timeout(
Duration::from_millis(50),
process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x41]),
},
&mut fast_writer,
ProtoTag::Intermediate,
&rng_fast,
&mut frame_buf_fast,
&stats_fast,
user,
quota_limit,
0,
&bytes_fast,
7002,
false,
false,
),
)
.await
.expect("peer connection must not be blocked by same-user stalled write")
.expect("fast peer write must succeed");
gate.open();
let slow_result = timeout(Duration::from_secs(1), slow_task)
.await
.expect("stalled task must complete once gate opens")
.expect("stalled task must not panic");
assert!(slow_result.is_ok());
}
#[tokio::test]
async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() {
let stats = Stats::new();
let user = "rollback-user";
stats.add_user_octets_from(user, 7);
let bytes_me2c = AtomicU64::new(0);
let rng = SecureRandom::new();
let mut writer = make_crypto_writer(FailingWriter);
let mut frame_buf = Vec::new();
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3, 4]),
},
&mut writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
user,
Some(64),
0,
&bytes_me2c,
7003,
false,
false,
)
.await;
assert!(matches!(result, Err(ProxyError::Io(_))));
assert_eq!(
stats.get_user_total_octets(user),
7,
"failed client write must not overcharge user quota accounting"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"failed client write must not inflate ME->C forensic byte counter"
);
}

View File

@ -3,7 +3,7 @@ use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
}
}
fn idle_pressure_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
match idle_pressure_test_lock().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
#[tokio::test]
async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
let (reader, _writer) = duplex(1024);
@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
#[test]
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
#[test]
fn pressure_does_not_evict_without_new_pressure_signal() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
#[test]
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
#[test]
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
#[test]
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
#[test]
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
#[test]
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
#[test]
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated(
#[test]
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
#[test]
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting(
#[test]
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
{
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());
@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());

View File

@ -0,0 +1,59 @@
use super::*;
use std::panic::{AssertUnwindSafe, catch_unwind};
#[test]
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let mut guard = registry
.lock()
.expect("registry lock must be acquired before poison");
guard.by_conn_id.insert(
999,
RelayIdleCandidateMeta {
mark_order_seq: 1,
mark_pressure_seq: 0,
},
);
guard.ordered.insert((1, 999));
panic!("intentional poison for idle-registry recovery");
}));
// Helper lock must recover from poison, reset stale state, and continue.
assert!(mark_relay_idle_candidate(42));
assert_eq!(oldest_relay_idle_candidate(), Some(42));
let before = relay_pressure_event_seq();
note_relay_pressure_event();
let after = relay_pressure_event_seq();
assert!(after > before, "pressure accounting must still advance after poison");
clear_relay_idle_pressure_state_for_testing();
}
#[test]
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let _guard = registry
.lock()
.expect("registry lock must be acquired before poison");
panic!("intentional poison while lock held");
}));
clear_relay_idle_pressure_state_for_testing();
assert_eq!(oldest_relay_idle_candidate(), None);
assert_eq!(relay_pressure_event_seq(), 0);
assert!(mark_relay_idle_candidate(7));
assert_eq!(oldest_relay_idle_candidate(), Some(7));
clear_relay_idle_pressure_state_for_testing();
}

View File

@ -0,0 +1,372 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, Mutex};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
fn lookup_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn positive_me2c_quota_counts_bytes_exactly_once() {
let _guard = lookup_test_lock().lock().unwrap();
let stats = Stats::new();
let user = format!("quota-middle-ext-positive-{}", std::process::id());
let lock = Arc::new(AsyncMutex::new(()));
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3, 4, 5]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(64),
0,
Some(&lock),
&bytes_me2c,
70_001,
false,
false,
)
.await;
assert!(result.is_ok());
assert_eq!(stats.get_user_total_octets(&user), 5);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
}
#[tokio::test]
async fn negative_held_crossmode_lock_blocks_me2c_write() {
let _guard = lookup_test_lock().lock().unwrap();
let stats = Stats::new();
let user = format!("quota-middle-ext-negative-{}", std::process::id());
let lock = Arc::new(AsyncMutex::new(()));
let _held = lock.try_lock().expect("lock must be held");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let blocked = timeout(
Duration::from_millis(25),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xFE]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(16),
0,
Some(&lock),
&bytes_me2c,
70_101,
false,
false,
),
)
.await;
assert!(blocked.is_err());
assert_eq!(stats.get_user_total_octets(&user), 0);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}
#[tokio::test]
async fn edge_zero_quota_zero_payload_is_fail_closed() {
let _guard = lookup_test_lock().lock().unwrap();
let stats = Stats::new();
let user = format!("quota-middle-ext-edge-{}", std::process::id());
let lock = Arc::new(AsyncMutex::new(()));
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::new(),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(0),
0,
Some(&lock),
&bytes_me2c,
70_201,
false,
false,
)
.await;
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
assert_eq!(stats.get_user_total_octets(&user), 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() {
let _guard = lookup_test_lock().lock().unwrap();
let stats = Arc::new(Stats::new());
let user = format!("quota-middle-ext-blackhat-{}", std::process::id());
let quota = 64u64;
let lock = Arc::new(AsyncMutex::new(()));
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut set = JoinSet::new();
for i in 0..256u64 {
let stats = Arc::clone(&stats);
let user = user.clone();
let lock = Arc::clone(&lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
set.spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize];
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from(payload),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(quota),
0,
Some(&lock),
bytes_me2c.as_ref(),
70_301 + i,
false,
false,
)
.await
});
}
let mut succeeded = 0usize;
while let Some(done) = set.join_next().await {
match done.expect("task must not panic") {
Ok(_) => succeeded += 1,
Err(ProxyError::DataQuotaExceeded { .. }) => {}
Err(other) => panic!("unexpected error {other:?}"),
}
}
assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed));
assert!(stats.get_user_total_octets(&user) <= quota);
assert!(succeeded <= quota as usize);
}
#[tokio::test]
async fn integration_shared_prefetched_lock_blocks_then_releases_writer() {
let stats = Stats::new();
let user = format!("quota-middle-ext-integration-{}", std::process::id());
let lock = Arc::new(AsyncMutex::new(()));
let held = lock
.try_lock()
.expect("integration test must hold prefetched lock first");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let blocked = timeout(
Duration::from_millis(25),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xA1]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(8),
0,
Some(&lock),
&bytes_me2c,
70_360,
false,
false,
),
)
.await;
assert!(blocked.is_err());
drop(held);
let after_release = timeout(
Duration::from_millis(150),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xA2]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(8),
0,
Some(&lock),
&bytes_me2c,
70_361,
false,
false,
),
)
.await
.expect("writer should progress once the shared lock is released");
assert!(after_release.is_ok());
}
#[tokio::test]
async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() {
let _guard = lookup_test_lock().lock().unwrap();
let stats = Stats::new();
let user = format!("quota-middle-ext-fuzz-{}", std::process::id());
let mut seed = 0xCAFE_BABE_1234u64;
let bytes_me2c = AtomicU64::new(0);
for case in 0..48u32 {
seed ^= seed << 5;
seed ^= seed >> 12;
seed ^= seed << 13;
let hold = (seed & 0x1) == 0;
let lock = Arc::new(AsyncMutex::new(()));
let maybe_guard = if hold {
Some(lock.try_lock().unwrap())
} else {
None
};
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let result = timeout(
Duration::from_millis(30),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(128),
0,
Some(&lock),
&bytes_me2c,
70_401 + case as u64,
false,
false,
),
)
.await;
if hold {
assert!(result.is_err());
} else {
assert!(result.unwrap().is_ok());
}
drop(maybe_guard);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() {
let _guard = lookup_test_lock().lock().unwrap();
let held = Arc::new(AsyncMutex::new(()));
let _held_guard = held.try_lock().unwrap();
let mut set = JoinSet::new();
for i in 0..48u64 {
set.spawn(async move {
let stats = Stats::new();
let user = format!("quota-middle-ext-stress-free-{i}");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let free_lock = Arc::new(AsyncMutex::new(()));
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xEE]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(1),
0,
Some(&free_lock),
&bytes_me2c,
70_500 + i,
false,
false,
)
.await
});
}
timeout(Duration::from_secs(2), async {
while let Some(task) = set.join_next().await {
task.unwrap().unwrap();
}
})
.await
.unwrap();
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,399 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom};
use crate::stats::Stats;
use crate::stream::CryptoWriter;
use bytes::Bytes;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
where
W: tokio::io::AsyncWrite + Unpin,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
fn lookup_counter_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("quota-extreme-positive-{}", std::process::id());
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
for idx in 0..12u64 {
let payload = vec![0x5A; ((idx % 4) + 1) as usize];
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from(payload),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(512),
0,
Some(&lock),
&bytes_me2c,
31_000 + idx,
false,
false,
)
.await;
assert!(result.is_ok());
}
assert_eq!(
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
0,
"prefetched lock path must avoid hot-path registry lookups"
);
assert_eq!(
stats.get_user_total_octets(&user),
bytes_me2c.load(Ordering::Relaxed),
"forensics and quota accounting must remain synchronized"
);
}
#[tokio::test]
async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("quota-extreme-negative-{}", std::process::id());
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold lock before calling ME->C writer");
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let blocked = timeout(
Duration::from_millis(25),
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[1, 2, 3]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(64),
0,
Some(&lock),
&bytes_me2c,
31_100,
false,
false,
),
)
.await;
assert!(blocked.is_err());
assert_eq!(stats.get_user_total_octets(&user), 0);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
drop(held_guard);
}
#[tokio::test]
async fn edge_zero_quota_and_zero_payload_is_fail_closed() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("quota-extreme-edge-{}", std::process::id());
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::new(),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(0),
0,
Some(&lock),
&bytes_me2c,
31_200,
false,
false,
)
.await;
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
assert_eq!(stats.get_user_total_octets(&user), 0);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Arc::new(Stats::new());
let user = format!("quota-extreme-blackhat-{}", std::process::id());
let quota = 80u64;
let overshoot = 7u64;
let soft_limit = quota + overshoot;
let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
let bytes_me2c = Arc::new(AtomicU64::new(0));
let mut set = JoinSet::new();
for idx in 0..256u64 {
let stats = Arc::clone(&stats);
let user = user.clone();
let lock = Arc::clone(&lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
set.spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let len = ((idx % 5) + 1) as usize;
let payload = vec![0xAA; len];
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from(payload),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(quota),
overshoot,
Some(&lock),
bytes_me2c.as_ref(),
31_300 + idx,
false,
false,
)
.await
});
}
while let Some(done) = set.join_next().await {
match done.expect("task must not panic") {
Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {}
Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"),
}
}
let total = stats.get_user_total_octets(&user);
assert!(
total <= soft_limit,
"parallel adversarial race must stay under soft cap"
);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
}
#[tokio::test]
async fn integration_without_prefetched_lock_uses_registry_lookup_path() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("quota-extreme-integration-{}", std::process::id());
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
for idx in 0..3u64 {
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x41]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(16),
0,
None,
&bytes_me2c,
31_400 + idx,
false,
false,
)
.await;
assert!(result.is_ok());
}
assert_eq!(
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
3,
"control path should perform one lock-registry lookup per call"
);
}
#[tokio::test]
async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Stats::new();
let user = format!("quota-extreme-fuzz-{}", std::process::id());
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let bytes_me2c = AtomicU64::new(0);
let mut seed = 0xA11C_55EE_2026_0323u64;
for idx in 0..512u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let quota = 24 + (seed & 0x3f);
let overshoot = (seed >> 13) & 0x0f;
let len = ((seed >> 19) & 0x07) + 1;
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
let before = stats.get_user_total_octets(&user);
let result = process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from(vec![0x11; len as usize]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
&user,
Some(quota),
overshoot,
Some(&lock),
&bytes_me2c,
31_500 + idx,
false,
false,
)
.await;
let after = stats.get_user_total_octets(&user);
if result.is_ok() {
assert!(after >= before);
} else {
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
assert_eq!(after, before);
}
assert_eq!(bytes_me2c.load(Ordering::Relaxed), after);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() {
let _guard = lookup_counter_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
let stats = Arc::new(Stats::new());
let user = format!("quota-extreme-stress-{}", std::process::id());
let quota = 96u64;
let lock: Arc<AsyncMutex<()>> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let bytes_me2c = Arc::new(AtomicU64::new(0));
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
let mut set = JoinSet::new();
for idx in 0..384u64 {
let stats = Arc::clone(&stats);
let user = user.clone();
let lock = Arc::clone(&lock);
let bytes_me2c = Arc::clone(&bytes_me2c);
set.spawn(async move {
let mut writer = make_crypto_writer(tokio::io::sink());
let mut frame_buf = Vec::new();
process_me_writer_response_with_cross_mode_lock(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xFF]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
stats.as_ref(),
&user,
Some(quota),
0,
Some(&lock),
bytes_me2c.as_ref(),
31_600 + idx,
false,
false,
)
.await
});
}
let mut success = 0usize;
while let Some(done) = set.join_next().await {
match done.expect("task must not panic") {
Ok(_) => success += 1,
Err(ProxyError::DataQuotaExceeded { .. }) => {}
Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"),
}
}
assert_eq!(success, quota as usize);
assert_eq!(stats.get_user_total_octets(&user), quota);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota);
assert_eq!(
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
0,
"stress prefetched path must not use lock registry lookups"
);
}

View File

@ -0,0 +1,361 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB200_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_once(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
let mut set = JoinSet::new();
for idx in 0..32u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(1000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = run_relay_test_step_timeout(
"tiny flood task",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel tiny flood worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
let mut set = JoinSet::new();
for idx in 0..24u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(2048);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(2000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [idx as u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(20);
for _ in 0..6 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let result = run_relay_test_step_timeout(
"benign tiny burst read",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("benign payload must parse")
.expect("benign payload must return frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel benign worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
let mut set = JoinSet::new();
for idx in 0..12u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(3000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(2000);
for n in 0..180u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
for chunk in encrypted.chunks(17) {
writer.write_all(chunk).await.unwrap();
sleep(TokioDuration::from_millis(1)).await;
}
drop(writer);
});
let mut closed = false;
for _ in 0..220 {
let result = run_relay_test_step_timeout(
"alternating jitter read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match result {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error in alternating jitter case: {other}"),
}
}
writer_task.await.expect("writer jitter task must not panic");
assert!(closed, "alternating attack must close before EOF");
});
}
while let Some(result) = set.join_next().await {
result.expect("alternating jitter worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_mixed_population_attackers_close_benign_survive() {
let mut set = JoinSet::new();
for idx in 0..20u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(4000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
if idx % 2 == 0 {
let mut plaintext = Vec::with_capacity(1280);
for n in 0..140u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n, n, n]);
}
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..200 {
match read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
{
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected attacker error: {other}"),
}
}
assert!(closed, "attacker session must fail closed");
} else {
let payload = [1u8, 9, 8, 7];
let mut plaintext = Vec::new();
for _ in 0..4 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
let got = read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("benign session must parse")
.expect("benign session must return a frame");
assert_eq!(got.0.as_ref(), &payload);
}
});
}
while let Some(result) = set.join_next().await {
result.expect("mixed-population worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
let mut set = JoinSet::new();
for case in 0..40u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(5000 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut seed = 0x9E37_79B9u64 ^ (case << 8);
let mut plaintext = Vec::with_capacity(2048);
for _ in 0..256 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let is_tiny = (seed & 1) == 0;
if is_tiny {
plaintext.push(0x00);
} else {
plaintext.push(0x01);
plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
}
}
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
drop(writer);
for _ in 0..320 {
let step = run_relay_test_step_timeout(
"fuzz case read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => break,
Ok(None) => break,
Err(other) => panic!("unexpected fuzz case error: {other}"),
}
}
});
}
while let Some(result) = set.join_next().await {
result.expect("fuzz worker must not panic");
}
}

View File

@ -0,0 +1,425 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB300_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
match proto {
ProtoTag::Abridged => plaintext.push(0x00),
ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()),
}
}
fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
match proto {
ProtoTag::Abridged => {
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
}
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&payload);
}
}
}
async fn write_chunked_with_jitter(
writer: &mut tokio::io::DuplexStream,
bytes: &[u8],
mut seed: u64,
) {
let mut offset = 0usize;
while offset < bytes.len() {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let chunk_len = 1 + ((seed as usize) & 0x1f);
let end = (offset + chunk_len).min(bytes.len());
writer.write_all(&bytes[offset..end]).await.unwrap();
let delay_ms = ((seed >> 16) % 3) as u64;
if delay_ms > 0 {
sleep(TokioDuration::from_millis(delay_ms)).await;
}
offset = end;
}
}
async fn read_once_with_state(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
matches!(result, Err(ProxyError::Proxy(_)))
|| matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
}
#[tokio::test]
async fn intermediate_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6101, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
drop(writer);
let result = run_relay_test_step_timeout(
"intermediate flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn secure_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6102, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
drop(writer);
let result = run_relay_test_step_timeout(
"secure flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"secure zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn intermediate_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6103, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"intermediate alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected intermediate alternating error: {other}"),
}
}
writer_task.await.expect("intermediate writer task must not panic");
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn secure_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6104, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"secure alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected secure alternating error: {other}"),
}
}
writer_task.await.expect("secure writer task must not panic");
assert!(closed, "secure alternating attack must fail closed");
}
#[tokio::test]
async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6105, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [9u8, 8, 7, 6];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("intermediate safe burst should parse")
.expect("intermediate safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn secure_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6106, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [3u8, 1, 4, 1];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("secure safe burst should parse")
.expect("secure safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
let mut seed = 0xDEAD_BEEF_2026_0322u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let proto = if (seed & 1) == 0 {
ProtoTag::Intermediate
} else {
ProtoTag::Secure
};
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6200 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut stream = Vec::new();
let mut local_seed = seed ^ case;
for _ in 0..220 {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
if (local_seed & 1) == 0 {
append_tiny_frame(&mut stream, proto);
} else {
let b = (local_seed >> 8) as u8;
append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
}
}
let encrypted = encrypt_for_reader(&stream);
write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
drop(writer);
for _ in 0..260 {
let step = run_relay_test_step_timeout(
"fuzz proto read step",
read_once_with_state(
&mut crypto_reader,
proto,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => break,
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
Ok(None) => break,
Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
}
}
}
}

View File

@ -0,0 +1,798 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB100_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_bounded(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto_tag: ProtoTag,
buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
stats: &Stats,
idle_policy: &RelayClientIdlePolicy,
idle_state: &mut RelayClientIdleState,
last_downstream_activity_ms: &AtomicU64,
session_started_at: Instant,
) -> Result<Option<(PooledBuffer, bool)>> {
run_relay_test_step_timeout(
"tiny-frame debt read step",
read_client_payload_with_idle_policy(
crypto_reader,
proto_tag,
1024,
buffer_pool,
forensics,
frame_counter,
stats,
idle_policy,
idle_state,
last_downstream_activity_ms,
session_started_at,
),
)
.await
}
fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
let mut debt = 0u32;
let mut reals = 0usize;
for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
if is_tiny {
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if debt >= TINY_FRAME_DEBT_LIMIT {
return (Some(idx + 1), debt, reals);
}
} else {
reals = reals.saturating_add(1);
debt = debt.saturating_sub(1);
}
}
(None, debt, reals)
}
#[test]
fn tiny_frame_debt_constants_match_security_budget_expectations() {
assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
}
#[test]
fn relay_client_idle_state_initial_debt_is_zero() {
let state = RelayClientIdleState::new(Instant::now());
assert_eq!(state.tiny_frame_debt, 0);
}
#[test]
fn on_client_frame_does_not_reset_tiny_frame_debt() {
let now = Instant::now();
let mut state = RelayClientIdleState::new(now);
state.tiny_frame_debt = 77;
state.on_client_frame(now);
assert_eq!(state.tiny_frame_debt, 77);
}
#[test]
fn tiny_frame_debt_increment_is_saturating() {
let mut debt = u32::MAX - 1;
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
assert_eq!(debt, u32::MAX);
}
#[test]
fn tiny_frame_debt_decrement_is_saturating() {
let mut debt = 0u32;
debt = debt.saturating_sub(1);
assert_eq!(debt, 0);
}
#[test]
fn consecutive_tiny_frames_close_exactly_at_threshold() {
let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
let pattern = vec![true; max_tiny_without_close];
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, Some(max_tiny_without_close));
}
#[test]
fn one_less_than_threshold_tiny_frames_do_not_close() {
let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
let pattern = vec![true; tiny_count];
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
#[test]
fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
let mut pattern = Vec::with_capacity(512);
for _ in 0..256 {
pattern.push(true);
pattern.push(false);
}
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some());
assert!(reals <= 80, "expected bounded real frames before close, got {reals}");
}
#[test]
fn alternating_one_to_eight_is_stable_for_long_runs() {
let mut pattern = Vec::with_capacity(9 * 5000);
for _ in 0..5000 {
pattern.push(true);
for _ in 0..8 {
pattern.push(false);
}
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
}
#[test]
fn alternating_one_to_seven_eventually_closes() {
let mut pattern = Vec::with_capacity(8 * 2000);
for _ in 0..2000 {
pattern.push(true);
for _ in 0..7 {
pattern.push(false);
}
}
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close");
}
#[test]
fn two_tiny_one_real_closes_faster_than_one_to_one() {
let mut one_to_one = Vec::with_capacity(512);
for _ in 0..256 {
one_to_one.push(true);
one_to_one.push(false);
}
let mut two_to_one = Vec::with_capacity(768);
for _ in 0..256 {
two_to_one.push(true);
two_to_one.push(true);
two_to_one.push(false);
}
let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len());
let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len());
assert!(a_close.is_some() && b_close.is_some());
assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0));
}
#[test]
fn burst_then_drain_can_recover_without_close() {
let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize;
let mut pattern = Vec::with_capacity(burst_tiny + 600);
for _ in 0..burst_tiny {
pattern.push(true);
}
pattern.extend(std::iter::repeat_n(false, 600));
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert_eq!(debt, 0);
}
#[test]
fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() {
let mut seed = 0xA5A5_91C3_2026_0322u64;
for _case in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = 512 + ((seed as usize) & 0x3ff);
let mut pattern = Vec::with_capacity(len);
let mut local_seed = seed;
for _ in 0..len {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
pattern.push((local_seed & 1) == 0);
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
if closed_at.is_none() {
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
assert!(debt <= u32::MAX);
}
}
#[test]
fn stress_many_independent_simulations_keep_isolated_debt_state() {
for idx in 0..2048usize {
let mut pattern = Vec::with_capacity(64);
for j in 0..64usize {
pattern.push(((idx ^ j) & 3) == 0);
}
let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY));
}
}
#[tokio::test]
async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(11, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(12, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Secure,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn intermediate_alternating_zero_and_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(13, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(3000);
for idx in 0..160u8 {
plaintext.extend_from_slice(&0u32.to_le_bytes());
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..220 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error while probing alternating close: {other}"),
}
}
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(14, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(64);
for _ in 0..8 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&[1, 2, 3, 4]);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let first = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match first {
Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]),
Err(e) => panic!("unexpected close after small tiny burst: {e}"),
Ok(None) => panic!("unexpected EOF before real frame"),
}
}
#[tokio::test]
async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"idle policy enabled must fail closed for pure zero-length flood"
);
}
#[tokio::test]
async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(256 * 6);
for idx in 0..=255u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("alternating flood bytes must be writable");
drop(writer);
let mut saw_proxy_close = false;
for _ in 0..300 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some((_payload, _quickack))) => {}
Err(ProxyError::Proxy(_)) => {
saw_proxy_close = true;
break;
}
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"),
Ok(None) => panic!("unexpected EOF before debt-based closure"),
Err(other) => panic!("unexpected error before close: {other}"),
}
}
assert!(
saw_proxy_close,
"alternating tiny/real sequence must eventually fail closed"
);
}
#[tokio::test]
async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(3, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let payload = [7u8, 8, 9, 10];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero frame must be writable");
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
.expect("valid frame should decode")
.expect("valid frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn abridged_quickack_tiny_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(21, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0x80u8; 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"quickack-marked zero-length flood must fail closed"
);
}
#[tokio::test]
async fn abridged_extended_zero_len_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(22, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut flood_plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]);
}
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"extended zero-length abridged flood must fail closed"
);
}
#[tokio::test]
async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() {
let mut plaintext = Vec::with_capacity(9 * 300);
for idx in 0..300usize {
plaintext.push(0x00);
for _ in 0..8 {
let b = idx as u8;
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]);
}
}
// Keep the test single-task and deterministic: make duplex capacity larger than the
// generated ciphertext so write_all cannot block waiting for a concurrent reader.
let duplex_capacity = plaintext.len().saturating_add(1024);
let (reader, mut writer) = duplex(duplex_capacity);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(23, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..3000 {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Err(other) => panic!("unexpected error in 1:8 wire test: {other}"),
}
}
assert!(
!closed,
"wire-level 1:8 tiny-to-real pattern should not trigger debt close"
);
}
#[tokio::test]
async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() {
let mut seed = 0xD1CE_BAAD_2026_0322u64;
for case_idx in 0..32u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let events = 300 + ((seed as usize) & 0xff);
let mut pattern = Vec::with_capacity(events);
let mut local = seed;
for _ in 0..events {
local ^= local << 7;
local ^= local >> 9;
local ^= local << 8;
pattern.push((local & 0x03) == 0);
}
let mut plaintext = Vec::with_capacity(events * 6);
for (idx, tiny) in pattern.iter().copied().enumerate() {
if tiny {
plaintext.push(0x00);
} else {
let b = (idx as u8) ^ (case_idx as u8);
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]);
}
}
let (reader, mut writer) = duplex(16 * 1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(500 + case_idx, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
let mut observed_close = false;
for _ in 0..(events + 8) {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
observed_close = true;
break;
}
Err(other) => panic!("unexpected fuzz error: {other}"),
}
}
assert_eq!(
observed_close,
expected_close.is_some(),
"wire parser behavior must match debt model for case {case_idx}"
);
}
}

View File

@ -0,0 +1,121 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use std::time::Instant;
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB000_0000 + conn_id,
conn_id,
user: format!("zero-len-test-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
#[tokio::test]
async fn adversarial_legacy_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let flood_plaintext = vec![0u8; 128];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
match result {
Err(ProxyError::Proxy(msg)) => {
assert!(
msg.contains("Excessive zero-length"),
"legacy mode must close flood with explicit zero-length reason, got: {msg}"
);
}
Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"),
Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"),
Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"),
}
}
#[tokio::test]
async fn business_abridged_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let payload = [1u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero abridged frame must be writable");
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("valid abridged frame should decode")
.expect("valid abridged frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1, "quickack flag must remain false");
assert_eq!(frame_counter, 1);
}

View File

@ -0,0 +1,108 @@
use super::*;
use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK
.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn same_user_returns_same_lock_identity() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let a = cross_mode_quota_user_lock("cross-mode-same-user");
let b = cross_mode_quota_user_lock("cross-mode-same-user");
assert!(
Arc::ptr_eq(&a, &b),
"same user must reuse a stable lock identity"
);
}
#[test]
fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let prefix = format!("cross-mode-saturated-{}", std::process::id());
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX);
for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX {
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"lock cache must be saturated for overflow check"
);
let overflow_user = format!("cross-mode-overflow-{}", std::process::id());
let overflow_a = cross_mode_quota_user_lock(&overflow_user);
let overflow_b = cross_mode_quota_user_lock(&overflow_user);
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock cache"
);
assert!(
locks.get(&overflow_user).is_none(),
"overflow user must stay on striped fallback while cache is saturated"
);
assert!(
Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user must receive a stable striped lock across repeated lookups"
);
drop(retained);
}
#[test]
fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() {
let _guard = cross_mode_lock_test_guard();
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
locks.clear();
let prefix = format!("cross-mode-reclaim-{}", std::process::id());
let protected_user = format!("{prefix}-protected");
let protected_lock = cross_mode_quota_user_lock(&protected_user);
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1));
for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) {
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
locks.len(),
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
"fixture must saturate lock cache before reclaim path is exercised"
);
drop(retained);
let newcomer_user = format!("{prefix}-newcomer");
let _newcomer = cross_mode_quota_user_lock(&newcomer_user);
assert!(
locks.get(&protected_user).is_some(),
"active protected user must remain cache-resident after reclaim"
);
let locked = locks
.get(&protected_user)
.expect("protected user must remain in map after reclaim");
assert!(
Arc::ptr_eq(locked.value(), &protected_lock),
"reclaim must not swap active user lock identity"
);
assert!(
locks.get(&newcomer_user).is_some(),
"newcomer should become cacheable after stale entries are reclaimed"
);
}

View File

@ -0,0 +1,267 @@
use super::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() {
let _guard = quota_test_guard();
let user = format!("relay-pipeline-stall-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold shared cross-mode lock");
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay_task = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
relay_stats,
Some(1024),
Arc::new(BufferPool::new()),
)
.await
});
server_peer
.write_all(&[0xA1])
.await
.expect("server write should enqueue while relay is stalled");
let mut one = [0u8; 1];
let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await;
assert!(
blocked_read.is_err(),
"same-user relay must remain blocked while cross-mode lock is held"
);
drop(held_guard);
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
.await
.expect("blocked relay must resume after cross-mode lock release")
.expect("resumed relay must deliver queued byte");
assert_eq!(one, [0xA1]);
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(1), relay_task)
.await
.expect("relay task must complete")
.expect("relay task must not panic");
assert!(relay_result.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() {
let _guard = quota_test_guard();
let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id());
let free_user = format!("relay-pipeline-free-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
let held_guard = held
.try_lock()
.expect("test must hold blocked user's shared cross-mode lock");
let stats_blocked = Arc::new(Stats::new());
let stats_free = Arc::new(Stats::new());
let (mut blocked_client, blocked_relay_client) = duplex(1024);
let (blocked_relay_server, mut blocked_server) = duplex(1024);
let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client);
let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server);
let (mut free_client, free_relay_client) = duplex(1024);
let (free_relay_server, mut free_server) = duplex(1024);
let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client);
let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server);
let blocked_task = {
let user = blocked_user.clone();
let stats = Arc::clone(&stats_blocked);
tokio::spawn(async move {
relay_bidirectional(
blocked_client_reader,
blocked_client_writer,
blocked_server_reader,
blocked_server_writer,
256,
256,
&user,
stats,
Some(1024),
Arc::new(BufferPool::new()),
)
.await
})
};
let free_task = {
let user = free_user.clone();
let stats = Arc::clone(&stats_free);
tokio::spawn(async move {
relay_bidirectional(
free_client_reader,
free_client_writer,
free_server_reader,
free_server_writer,
256,
256,
&user,
stats,
Some(1024),
Arc::new(BufferPool::new()),
)
.await
})
};
blocked_server
.write_all(&[0xB1])
.await
.expect("blocked user server write should queue");
free_server
.write_all(&[0xC1])
.await
.expect("free user server write should queue");
let mut blocked_buf = [0u8; 1];
let mut free_buf = [0u8; 1];
let blocked_stalled = timeout(
Duration::from_millis(40),
blocked_client.read_exact(&mut blocked_buf),
)
.await;
assert!(
blocked_stalled.is_err(),
"blocked user must remain stalled while its lock is held"
);
timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf))
.await
.expect("free user must make progress while other user is blocked")
.expect("free user read must succeed");
assert_eq!(free_buf, [0xC1]);
drop(held_guard);
timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf))
.await
.expect("blocked user must resume after release")
.expect("blocked user resumed read must succeed");
assert_eq!(blocked_buf, [0xB1]);
drop(blocked_client);
drop(blocked_server);
drop(free_client);
drop(free_server);
assert!(
timeout(Duration::from_secs(1), blocked_task)
.await
.expect("blocked relay task must complete")
.expect("blocked relay task must not panic")
.is_ok()
);
assert!(
timeout(Duration::from_secs(1), free_task)
.await
.expect("free relay task must complete")
.expect("free relay task must not panic")
.is_ok()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() {
let _guard = quota_test_guard();
let mut seed = 0x5EED_C0DE_2026_0323u64;
for round in 0..24u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold_ms = 2 + (seed % 10);
let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock during fuzz round");
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay_task = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
relay_stats,
Some(1024),
Arc::new(BufferPool::new()),
)
.await
});
server_peer
.write_all(&[0xD1])
.await
.expect("server write should queue in fuzz round");
let mut one = [0u8; 1];
let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await;
assert!(stalled.is_err(), "held phase must stall same-user relay");
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
drop(held_guard);
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
.await
.expect("released phase must resume same-user relay")
.expect("released phase read must succeed");
assert_eq!(one, [0xD1]);
drop(client_peer);
drop(server_peer);
assert!(
timeout(Duration::from_secs(1), relay_task)
.await
.expect("fuzz relay task must complete")
.expect("fuzz relay task must not panic")
.is_ok()
);
}
}

View File

@ -0,0 +1,213 @@
use super::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::sync::{Barrier, watch};
use tokio::time::{Duration, Instant, timeout};
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn percentile_index(len: usize, percentile: usize) -> usize {
((len * percentile) / 100).min(len.saturating_sub(1))
}
#[tokio::test]
async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() {
let _guard = quota_test_guard();
let rounds = 64usize;
let user = format!("relay-pipeline-latency-single-{}", std::process::id());
let mut samples_ms = Vec::with_capacity(rounds);
for round in 0..rounds {
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold shared cross-mode lock before round");
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay_task = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
relay_stats,
Some(2048),
Arc::new(BufferPool::new()),
)
.await
});
server_peer
.write_all(&[(round as u8) ^ 0xA5])
.await
.expect("server write should queue before release");
let release_at = Instant::now();
drop(held_guard);
let mut one = [0u8; 1];
timeout(Duration::from_millis(450), client_peer.read_exact(&mut one))
.await
.expect("client must receive queued byte after release")
.expect("queued byte read must succeed");
samples_ms.push(release_at.elapsed().as_millis() as u64);
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(1), relay_task)
.await
.expect("relay task must complete")
.expect("relay task must not panic");
assert!(relay_result.is_ok());
}
samples_ms.sort_unstable();
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
assert!(
p50_ms <= 45,
"single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}"
);
assert!(
p95_ms <= 130,
"single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() {
let _guard = quota_test_guard();
let waiters = 128usize;
let user = format!("relay-pipeline-latency-fanout-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold shared lock before fanout release benchmark");
let ready_barrier = Arc::new(Barrier::new(waiters + 1));
let release_at = Arc::new(Mutex::new(None::<Instant>));
let (release_tx, release_rx) = watch::channel(false);
let mut tasks = Vec::with_capacity(waiters);
for idx in 0..waiters {
let user = user.clone();
let barrier = Arc::clone(&ready_barrier);
let release_at = Arc::clone(&release_at);
let mut release_rx = release_rx.clone();
tasks.push(tokio::spawn(async move {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(512);
let (relay_server, mut server_peer) = duplex(512);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user;
let relay_stats = Arc::clone(&stats);
let relay_task = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
relay_stats,
Some(2048),
Arc::new(BufferPool::new()),
)
.await
});
server_peer
.write_all(&[(idx as u8) ^ 0x5A])
.await
.expect("fanout server write should queue before release");
barrier.wait().await;
release_rx
.changed()
.await
.expect("release signal should remain available");
let started = {
let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
guard.expect("release timestamp must be populated before signal")
};
let mut one = [0u8; 1];
timeout(Duration::from_millis(900), client_peer.read_exact(&mut one))
.await
.expect("fanout waiter must receive queued byte after release")
.expect("fanout waiter read must succeed");
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("fanout relay task must complete")
.expect("fanout relay task must not panic");
assert!(relay_result.is_ok());
started.elapsed().as_millis() as u64
}));
}
ready_barrier.wait().await;
{
let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
*guard = Some(Instant::now());
}
drop(held_guard);
release_tx
.send(true)
.expect("release broadcast must succeed");
let mut samples_ms = Vec::with_capacity(waiters);
timeout(Duration::from_secs(8), async {
for task in tasks {
let elapsed = task.await.expect("fanout waiter must not panic");
samples_ms.push(elapsed);
}
})
.await
.expect("fanout benchmark must complete in bounded time");
samples_ms.sort_unstable();
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
let max_ms = *samples_ms.last().unwrap_or(&0);
assert!(
p50_ms <= 120,
"fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
);
assert!(
p95_ms <= 260,
"fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
);
assert!(
max_ms <= 700,
"fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
);
}

View File

@ -0,0 +1,604 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn positive_cross_mode_uncontended_writer_progresses() {
let _guard = quota_test_guard();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
"cross-mode-tdd-uncontended".to_string(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let result = io.write_all(&[0x11, 0x22]).await;
assert!(result.is_ok(), "uncontended writer must progress");
}
#[tokio::test]
async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-held-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
assert!(poll.is_pending(), "writer must not bypass held cross-mode lock");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_parallel_waiters_resume_after_cross_mode_release() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-resume-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before launching waiters");
let stats = Arc::new(Stats::new());
let mut waiters = Vec::new();
for _ in 0..16 {
let stats = Arc::clone(&stats);
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
stats,
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x7F]).await
}));
}
tokio::time::sleep(Duration::from_millis(5)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let result = waiter.await.expect("waiter task must not panic");
assert!(result.is_ok(), "waiter must complete after cross-mode release");
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test]
async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() {
let _guard = quota_test_guard();
let user = format!("cross-mode-tdd-wakes-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before polling");
let stats = Arc::new(Stats::new());
let mut ios = Vec::new();
let mut counters = Vec::new();
for _ in 0..20 {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0x33]);
assert!(poll.is_pending());
counters.push(wake_counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= 20 * 4,
"cross-mode contention should not create wake storms; wakes={total_wakes}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() {
let _guard = quota_test_guard();
let mut seed = 0xC0DE_BAAD_2026_0322u64;
for round in 0..16u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let sleep_ms = 2 + (seed as u64 % 8);
let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock in fuzz round");
let stats = Arc::new(Stats::new());
let user_reader = user.clone();
let reader_task = tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user_reader,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let mut one = [0u8; 1];
io.read(&mut one).await
});
let user_writer = user.clone();
let writer_task = tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user_writer,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x44]).await
});
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
drop(held_guard);
let read_done = timeout(Duration::from_millis(350), reader_task)
.await
.expect("reader task must complete after release")
.expect("reader task must not panic");
assert!(read_done.is_ok());
let write_done = timeout(Duration::from_millis(350), writer_task)
.await
.expect("writer task must complete after release")
.expect("writer task must not panic");
assert!(write_done.is_ok());
}
}
#[tokio::test]
async fn integration_middle_lock_blocks_relay_reader_for_same_user() {
let _guard = quota_test_guard();
let user = format!("cross-mode-middle-reader-block-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let _held_guard = held
.try_lock()
.expect("test must hold middle-relay shared lock");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let mut one = [0u8; 1];
let mut buf = ReadBuf::new(&mut one);
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn integration_middle_lock_release_unblocks_relay_reader() {
let _guard = quota_test_guard();
let user = format!("cross-mode-middle-reader-release-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold middle-relay shared lock");
let task = tokio::spawn({
let user = user.clone();
async move {
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let mut one = [0u8; 1];
io.read(&mut one).await
}
});
tokio::time::sleep(Duration::from_millis(5)).await;
drop(held_guard);
let done = timeout(Duration::from_millis(300), task)
.await
.expect("reader task must complete after release")
.expect("reader task must not panic");
assert!(done.is_ok());
}
#[tokio::test]
async fn business_different_user_middle_lock_does_not_block_relay_writer() {
let _guard = quota_test_guard();
let held_user = format!("cross-mode-middle-held-{}", std::process::id());
let active_user = format!("cross-mode-middle-active-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user);
let _held_guard = held
.try_lock()
.expect("test must hold middle-relay lock for other user");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
active_user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]);
assert!(matches!(poll, Poll::Ready(Ok(1))));
}
#[tokio::test]
async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() {
let _guard = quota_test_guard();
let user = format!("cross-mode-none-limit-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let _held_guard = held
.try_lock()
.expect("test must hold lock while quota is disabled");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
None,
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]);
assert!(matches!(poll, Poll::Ready(Ok(2))));
}
#[tokio::test]
async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() {
let _guard = quota_test_guard();
let user = format!("cross-mode-pre-exceeded-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let _held_guard = held
.try_lock()
.expect("test must hold shared lock before poll");
let quota_exceeded = Arc::new(AtomicBool::new(true));
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::clone(&quota_exceeded),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]);
assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e)));
}
#[tokio::test]
async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() {
let _guard = quota_test_guard();
let user = format!("cross-mode-repoll-held-{}", std::process::id());
let stats = Arc::new(Stats::new());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let _held_guard = held
.try_lock()
.expect("test must hold lock for repoll sequence");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
for _ in 0..8 {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]);
assert!(poll.is_pending());
}
assert_eq!(stats.get_user_total_octets(&user), 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_same_user_mixed_read_write_waiters_resume_after_release() {
let _guard = quota_test_guard();
let user = format!("cross-mode-mixed-resume-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock before spawning mixed waiters");
let mut tasks = Vec::new();
for i in 0..12usize {
let user = user.clone();
tasks.push(tokio::spawn(async move {
if i % 2 == 0 {
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let mut b = [0u8; 1];
io.read(&mut b).await.map(|_| ())
} else {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x66]).await
}
}));
}
tokio::time::sleep(Duration::from_millis(8)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for task in tasks {
let result = task.await.expect("mixed waiter task must not panic");
assert!(result.is_ok());
}
})
.await
.expect("all mixed waiters must finish after release");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() {
let _guard = quota_test_guard();
let blocked_user = format!("cross-mode-blocked-{}", std::process::id());
let free_user = format!("cross-mode-free-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
let held_guard = held
.try_lock()
.expect("test must hold blocked user lock");
let blocked_task = tokio::spawn({
let blocked_user = blocked_user.clone();
async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
blocked_user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x77]).await
}
});
let free_task = tokio::spawn({
let free_user = free_user.clone();
async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
free_user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x78]).await
}
});
let free_done = timeout(Duration::from_millis(250), free_task)
.await
.expect("free user must not be blocked")
.expect("free user task must not panic");
assert!(free_done.is_ok());
drop(held_guard);
let blocked_done = timeout(Duration::from_secs(1), blocked_task)
.await
.expect("blocked user must resume after release")
.expect("blocked user task must not panic");
assert!(blocked_done.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() {
let _guard = quota_test_guard();
let user = format!("cross-mode-fanout-{}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock before fanout");
let waiters = 48usize;
let gate = Arc::new(Barrier::new(waiters + 1));
let mut tasks = Vec::new();
for _ in 0..waiters {
let user = user.clone();
let gate = Arc::clone(&gate);
tasks.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
gate.wait().await;
io.write_all(&[0x79]).await
}));
}
gate.wait().await;
tokio::time::sleep(Duration::from_millis(10)).await;
drop(held_guard);
timeout(Duration::from_secs(2), async {
for task in tasks {
let result = task.await.expect("fanout task must not panic");
assert!(result.is_ok());
}
})
.await
.expect("fanout waiters must complete after release");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() {
let _guard = quota_test_guard();
let mut seed = 0xA11C_EE55_2026_0323u64;
for round in 0..20u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold_ms = 2 + (seed % 10);
let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id());
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
let held_guard = held
.try_lock()
.expect("test must hold lock in fuzz round");
let writer = tokio::spawn({
let user = user.clone();
async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x7A]).await
}
});
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
drop(held_guard);
let done = timeout(Duration::from_millis(400), writer)
.await
.expect("writer must complete after lock release")
.expect("writer task must not panic");
assert!(done.is_ok());
}
}

View File

@ -0,0 +1,81 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::Waker;
use std::task::{Context, Poll};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() {
let _guard = quota_user_lock_test_scope();
let user = "cross-mode-lock-shared-user";
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user);
let _held_guard = held
.try_lock()
.expect("test must hold shared cross-mode lock before relay poll");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(crate::stats::Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]);
assert!(
matches!(poll, Poll::Pending),
"relay writer must not bypass cross-mode lock held by middle-relay path"
);
}
#[tokio::test]
async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() {
let _guard = quota_user_lock_test_scope();
let user = "cross-mode-lock-progress-user";
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(crate::stats::Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]);
assert!(
matches!(poll, Poll::Ready(Ok(2))),
"relay writer should progress when shared cross-mode lock is uncontended"
);
}

View File

@ -0,0 +1,340 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, Instant, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() {
let _guard = quota_test_guard();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
format!("dual-lock-alt-positive-{}", std::process::id()),
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let write = io.write_all(&[0xAA, 0xBB]).await;
assert!(write.is_ok(), "uncontended write must complete");
assert_eq!(
io.quota_write_retry_attempt, 0,
"uncontended write must not advance retry backoff"
);
}
#[tokio::test]
async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() {
let _guard = quota_test_guard();
let user = format!("dual-lock-alt-adversarial-{}", std::process::id());
let local_lock = quota_user_lock(&user);
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let mut local_guard = Some(
local_lock
.try_lock()
.expect("test must hold local quota lock initially"),
);
let mut cross_guard = None;
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
assert!(first.is_pending(), "held local lock must block first poll");
let mut observed_wakes = 0usize;
for idx in 0..18usize {
tokio::time::sleep(Duration::from_millis(6)).await;
if idx % 2 == 0 {
drop(local_guard.take());
cross_guard = Some(
cross_mode_lock
.try_lock()
.expect("cross-mode lock should be acquirable while local lock released"),
);
} else {
drop(cross_guard.take());
local_guard = Some(
local_lock
.try_lock()
.expect("local lock should be acquirable while cross lock released"),
);
}
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed_wakes {
observed_wakes = wakes;
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
assert!(
pending.is_pending(),
"alternating contention must keep write pending while one lock is held"
);
}
}
assert!(
io.quota_write_retry_attempt >= 2,
"alternating contention must still ramp retry backoff; got {}",
io.quota_write_retry_attempt
);
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 32,
"alternating contention must stay wake-rate-limited"
);
drop(local_guard);
drop(cross_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]);
assert!(ready.is_ready(), "writer must resume after both locks released");
}
#[tokio::test]
async fn edge_retry_scheduler_resets_after_alternating_contention_clears() {
let _guard = quota_test_guard();
let user = format!("dual-lock-alt-edge-reset-{}", std::process::id());
let local_lock = quota_user_lock(&user);
let local_guard = local_lock
.try_lock()
.expect("test must hold local lock for edge scenario");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]);
assert!(first.is_pending());
tokio::time::sleep(Duration::from_millis(15)).await;
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
assert!(next.is_pending());
}
drop(local_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]);
assert!(ready.is_ready());
assert_eq!(
io.quota_write_retry_attempt, 0,
"successful dual-lock acquisition must reset retry scheduler"
);
assert!(!io.quota_write_wake_scheduled);
assert!(io.quota_write_retry_sleep.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() {
let _guard = quota_test_guard();
let user = format!("dual-lock-alt-integration-{}", std::process::id());
let local_lock = quota_user_lock(&user);
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let mut waiters = Vec::new();
for _ in 0..16usize {
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
timeout(Duration::from_secs(2), io.write_all(&[0x31])).await
}));
}
let mut local_guard = Some(
local_lock
.try_lock()
.expect("integration toggle must acquire local lock first"),
);
let mut cross_guard = None;
for idx in 0..24usize {
tokio::time::sleep(Duration::from_millis(4)).await;
if idx % 2 == 0 {
drop(local_guard.take());
cross_guard = cross_mode_lock.try_lock().ok();
} else {
drop(cross_guard.take());
local_guard = local_lock.try_lock().ok();
}
}
drop(local_guard);
drop(cross_guard);
for waiter in waiters {
let done = waiter.await.expect("waiter task must not panic");
assert!(
done.is_ok(),
"waiter must finish once alternating contention window ends"
);
assert!(done.expect("waiter timeout must not fire").is_ok());
}
}
#[tokio::test]
async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() {
let _guard = quota_test_guard();
let user = format!("dual-lock-alt-fuzz-{}", std::process::id());
let local_lock = quota_user_lock(&user);
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let mut seed = 0xD00D_BAAD_F00D_2026u64;
for _round in 0..64u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold_mode = (seed % 3) as u8;
let local_guard = if hold_mode == 0 {
Some(
local_lock
.try_lock()
.expect("fuzz local lock should be acquirable"),
)
} else {
None
};
let cross_guard = if hold_mode == 1 {
Some(
cross_mode_lock
.try_lock()
.expect("fuzz cross lock should be acquirable"),
)
} else {
None
};
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await;
if hold_mode == 2 {
assert!(write.is_ok(), "unheld fuzz round must make progress");
assert!(write.expect("unheld round timeout").is_ok());
} else {
assert!(
write.is_err(),
"held-lock fuzz round must remain pending inside bounded window"
);
}
drop(local_guard);
drop(cross_guard);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_fanout_alternating_contention_recovers_without_hanging() {
let _guard = quota_test_guard();
let user = format!("dual-lock-alt-stress-{}", std::process::id());
let local_lock = quota_user_lock(&user);
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let mut waiters = Vec::new();
for _ in 0..48usize {
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await
}));
}
let mut local_guard = Some(
local_lock
.try_lock()
.expect("stress toggle must acquire local lock first"),
);
let mut cross_guard = None;
for idx in 0..40usize {
tokio::time::sleep(Duration::from_millis(3)).await;
if idx % 2 == 0 {
drop(local_guard.take());
cross_guard = cross_mode_lock.try_lock().ok();
} else {
drop(cross_guard.take());
local_guard = local_lock.try_lock().ok();
}
}
drop(local_guard);
drop(cross_guard);
for waiter in waiters {
let done = waiter.await.expect("stress waiter task must not panic");
assert!(done.is_ok(), "stress waiter timed out under alternating contention");
assert!(done.expect("stress waiter timeout should not fire").is_ok());
}
}

View File

@ -0,0 +1,74 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::time::{Duration, Instant};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() {
let _guard = quota_test_guard();
let user = format!("dual-lock-backoff-{}", std::process::id());
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_cross_mode_guard = cross_mode_lock
.try_lock()
.expect("test must hold cross-mode lock before polling");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
assert!(first.is_pending(), "held cross-mode lock must block writer");
let started = Instant::now();
let mut last_wakes = 0usize;
while started.elapsed() < Duration::from_millis(120) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > last_wakes {
last_wakes = wakes;
let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
assert!(next.is_pending(), "writer must remain blocked while lock is held");
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
io.quota_write_retry_attempt >= 2,
"retry attempt must ramp under sustained second-lock contention; got {}",
io.quota_write_retry_attempt
);
drop(held_cross_mode_guard);
}

View File

@ -0,0 +1,325 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, Instant, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() {
let _guard = quota_test_guard();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
format!("dual-lock-positive-{}", std::process::id()),
Some(4096),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]);
assert!(poll.is_ready());
assert_eq!(io.quota_write_retry_attempt, 0);
assert!(!io.quota_write_wake_scheduled);
assert!(io.quota_write_retry_sleep.is_none());
}
#[tokio::test]
async fn negative_local_lock_contention_read_retry_attempt_ramps() {
let _guard = quota_test_guard();
let user = format!("dual-lock-local-contention-{}", std::process::id());
let held = quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold local quota lock before polling");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (wake_counter, mut cx) = build_context();
let mut one = [0u8; 1];
let mut buf = ReadBuf::new(&mut one);
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(first.is_pending());
let started = Instant::now();
let mut observed = 0usize;
while started.elapsed() < Duration::from_millis(120) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed {
observed = wakes;
let mut step_buf = ReadBuf::new(&mut one);
let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf);
assert!(next.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
io.quota_read_retry_attempt >= 2,
"retry attempt must ramp under sustained local-lock contention; got {}",
io.quota_read_retry_attempt
);
drop(held_guard);
}
#[tokio::test]
async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() {
let _guard = quota_test_guard();
let user = format!("dual-lock-reset-{}", std::process::id());
let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = cross_mode
.try_lock()
.expect("test must hold cross-mode lock before polling");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (wake_counter, mut cx) = build_context();
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]);
assert!(first.is_pending());
tokio::time::sleep(Duration::from_millis(20)).await;
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
assert!(next.is_pending());
}
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
assert!(ready.is_ready());
assert_eq!(io.quota_write_retry_attempt, 0);
assert!(!io.quota_write_wake_scheduled);
assert!(io.quota_write_retry_sleep.is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() {
let _guard = quota_test_guard();
let user = format!("dual-lock-adversarial-{}", std::process::id());
let stats = Arc::new(Stats::new());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before launching waiters");
let mut tasks = Vec::new();
for _ in 0..24usize {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
stats,
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
timeout(Duration::from_millis(40), io.write_all(&[0x33])).await
}));
}
for task in tasks {
let timed = task.await.expect("waiter task must not panic");
assert!(timed.is_err(), "held cross-mode lock must keep waiter pending");
}
assert_eq!(stats.get_user_total_octets(&user), 0);
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_waiters_resume_after_cross_mode_release() {
let _guard = quota_test_guard();
let user = format!("dual-lock-integration-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before starting waiter");
let task = tokio::spawn({
let user = user.clone();
async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
io.write_all(&[0x44]).await
}
});
tokio::time::sleep(Duration::from_millis(10)).await;
drop(held_guard);
let done = timeout(Duration::from_secs(1), task)
.await
.expect("waiter task must complete after release")
.expect("waiter task must not panic");
assert!(done.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() {
let _guard = quota_test_guard();
let user = format!("dual-lock-fuzz-{}", std::process::id());
let stats = Arc::new(Stats::new());
let mut seed = 0xA55A_55AA_C3D2_E1F0u64;
for _round in 0..48u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold_mode = (seed % 3) as u8;
let mut local_lock = None;
let mut cross_lock = None;
let mut local_guard = None;
let mut cross_guard = None;
if hold_mode == 0 {
local_lock = Some(quota_user_lock(&user));
local_guard = Some(
local_lock
.as_ref()
.expect("local lock should be present")
.try_lock()
.expect("local lock should be acquirable in fuzz round"),
);
} else if hold_mode == 1 {
cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
&user,
));
cross_guard = Some(
cross_lock
.as_ref()
.expect("cross lock should be present")
.try_lock()
.expect("cross lock should be acquirable in fuzz round"),
);
}
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await;
if hold_mode == 2 {
assert!(write.is_ok(), "unheld round must make progress");
} else {
assert!(write.is_err(), "held-lock round must stay blocked within timeout");
}
drop(local_guard);
drop(cross_guard);
drop(local_lock);
drop(cross_lock);
}
assert!(stats.get_user_total_octets(&user) <= 4096);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_fanout_waiters_complete_after_release_without_panics() {
let _guard = quota_test_guard();
let user = format!("dual-lock-stress-{}", std::process::id());
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold cross-mode lock before stress fanout");
let waiters = 64usize;
let mut tasks = Vec::new();
for _ in 0..waiters {
let user = user.clone();
tasks.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(1024),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let mut one = [0u8; 1];
io.read(&mut one).await
}));
}
tokio::time::sleep(Duration::from_millis(12)).await;
drop(held_guard);
timeout(Duration::from_secs(2), async {
for task in tasks {
let result = task.await.expect("stress waiter task must not panic");
assert!(result.is_ok());
}
})
.await
.expect("all stress waiters must complete after release");
}

View File

@ -0,0 +1,128 @@
use super::*;
use crate::stats::Stats;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, timeout};
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn make_stats_io(user: String) -> StatsIo<tokio::io::Sink> {
StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
)
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() {
let _guard = quota_test_guard();
let user = format!("dual-lock-race-fuzz-{}", std::process::id());
let mut seed = 0xD1CE_BAAD_5EED_1234u64;
for round in 0..1024u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold = (seed & 1) == 0;
let hold_ms = (seed % 3) as u64;
let maybe_lock = if hold {
Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
&user,
))
} else {
None
};
let maybe_guard = maybe_lock.as_ref().map(|lock| {
lock.try_lock()
.expect("cross-mode lock must be acquirable in fuzz round")
});
if hold {
let mut blocked_io = make_stats_io(user.clone());
let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await;
assert!(
blocked.is_err(),
"held round must block waiter before lock release (round={round})"
);
if hold_ms > 0 {
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
}
} else {
let mut free_io = make_stats_io(user.clone());
let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await;
assert!(
free.is_ok(),
"unheld round must complete promptly (round={round})"
);
assert!(free.expect("unheld round should complete").is_ok());
}
drop(maybe_guard);
let done = timeout(Duration::from_millis(350), async {
let user = user.clone();
let mut io = make_stats_io(user);
io.write_all(&[0xA6]).await
})
.await
.expect("post-release write must complete in bounded time");
assert!(done.is_ok());
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() {
let _guard = quota_test_guard();
let user = format!("dual-lock-race-stress-{}", std::process::id());
let mut seed = 0xC0FF_EE77_4444_9999u64;
for round in 0..256u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let hold_ms = (seed % 4) as u64;
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
let guard = lock
.try_lock()
.expect("cross-mode lock must be acquirable at round start");
let mut waiters = Vec::new();
for _ in 0..3usize {
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = make_stats_io(user);
io.write_all(&[0x55]).await
}));
}
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
drop(guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let done = waiter.await.expect("waiter task must not panic");
assert!(
done.is_ok(),
"waiter must complete after release (round={round})"
);
}
})
.await
.expect("all waiters must complete in bounded time after release");
}
}

View File

@ -0,0 +1,332 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use std::sync::Arc;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::time::{Duration, timeout};
async fn read_available<R: tokio::io::AsyncRead + Unpin>(reader: &mut R, budget: Duration) -> usize {
let start = tokio::time::Instant::now();
let mut total = 0usize;
let mut buf = [0u8; 128];
loop {
let elapsed = start.elapsed();
if elapsed >= budget {
break;
}
let remaining = budget.saturating_sub(elapsed);
match timeout(remaining, reader.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => total = total.saturating_add(n),
Ok(Err(_)) | Err(_) => break,
}
}
total
}
#[tokio::test]
async fn positive_quota_path_forwards_both_directions_within_limit() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-positive-user";
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(16),
Arc::new(BufferPool::new()),
));
client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap();
server_peer.read_exact(&mut [0u8; 4]).await.unwrap();
server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap();
client_peer.read_exact(&mut [0u8; 4]).await.unwrap();
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
assert!(relay_result.is_ok());
assert!(stats.get_user_total_octets(user) <= 16);
}
#[tokio::test]
async fn negative_preloaded_quota_forbids_any_forwarding() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-negative-user";
stats.add_user_octets_from(user, 8);
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer.write_all(&[0xAA]).await.unwrap();
server_peer.write_all(&[0xBB]).await.unwrap();
assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0);
assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0);
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
assert!(stats.get_user_total_octets(user) <= 8);
}
#[tokio::test]
async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-edge-user";
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let _ = tokio::join!(
client_peer.write_all(&[0xFE]),
server_peer.write_all(&[0xEF]),
);
let mut buf = [0u8; 1];
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
assert!(delivered_s2c + delivered_c2s <= 1);
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-blackhat-user";
let quota = 24u64;
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(quota),
Arc::new(BufferPool::new()),
));
let mut total_forwarded = 0usize;
for i in 0..256usize {
if relay.is_finished() {
break;
}
if (i & 1) == 0 {
let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
total_forwarded += n;
}
}
tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
}
let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap();
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_total_octets(user) <= quota);
}
#[tokio::test]
async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);
for case in 0..32u64 {
let stats = Arc::new(Stats::new());
let user = format!("quota-extended-fuzz-{case}");
let quota = rng.random_range(1u64..=35u64);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total_forwarded = 0usize;
for _ in 0..96usize {
if relay.is_finished() {
break;
}
if rng.random::<bool>() {
let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await {
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await {
total_forwarded += n;
}
}
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_total_octets(&user) <= quota);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_relays_for_one_user_obey_global_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-stress-user".to_string();
let quota = 64u64;
let mut tasks = Vec::new();
for worker in 0..4u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total = 0usize;
for step in 0..64u8 {
if relay.is_finished() {
break;
}
if (step as usize + worker as usize) % 2 == 0 {
let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
total += n;
}
} else {
let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
total += n;
}
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
total
}));
}
let mut delivered = 0usize;
for task in tasks {
delivered += task.await.unwrap();
}
assert!(stats.get_user_total_octets(&user) <= quota);
assert!(delivered <= quota as usize);
}

View File

@ -0,0 +1,79 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::time::{Duration, timeout};
#[test]
fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let held_user = format!("quota-evict-held-{}", std::process::id());
let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id());
let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id());
let held = quota_user_lock(&held_user);
let stale_a = quota_user_lock(&stale_a_user);
let stale_b = quota_user_lock(&stale_b_user);
assert!(map.get(&held_user).is_some());
assert!(map.get(&stale_a_user).is_some());
assert!(map.get(&stale_b_user).is_some());
drop(stale_a);
drop(stale_b);
quota_user_lock_evict();
assert!(
map.get(&held_user).is_some(),
"held entry must survive eviction"
);
assert!(
map.get(&stale_a_user).is_none(),
"unheld stale entry must be reclaimed"
);
assert!(
map.get(&stale_b_user).is_none(),
"unheld stale entry must be reclaimed"
);
drop(held);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let held_user = format!("quota-evict-loop-held-{}", std::process::id());
let stale_user = format!("quota-evict-loop-stale-{}", std::process::id());
let held = quota_user_lock(&held_user);
let stale = quota_user_lock(&stale_user);
assert_eq!(map.len(), 2);
drop(stale);
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
timeout(Duration::from_millis(200), async {
loop {
if map.get(&stale_user).is_none() {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await
.expect("periodic quota lock evictor must reclaim stale entry");
evictor.abort();
assert!(map.get(&held_user).is_some());
assert!(map.get(&stale_user).is_none());
drop(held);
}

View File

@ -0,0 +1,153 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::task::JoinSet;
use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
let mut tasks = JoinSet::new();
for worker in 0..24u32 {
tasks.spawn(async move {
for round in 0..320u32 {
let user = format!(
"quota-evict-stress-user-{}-{}-{}",
std::process::id(),
worker,
round
);
let lock = quota_user_lock(&user);
if round % 19 == 0 {
tokio::task::yield_now().await;
}
drop(lock);
}
});
}
while let Some(done) = tasks.join_next().await {
done.expect("stress worker must not panic");
}
quota_user_lock_evict();
tokio::time::sleep(Duration::from_millis(20)).await;
assert!(
map.len() <= QUOTA_USER_LOCKS_MAX,
"quota lock map must remain bounded after churn + eviction"
);
let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id());
let sanity_lock = quota_user_lock(&sanity_user);
assert!(
map.get(&sanity_user).is_some(),
"sanity user should be cacheable after eviction reclaimed stale entries"
);
drop(sanity_lock);
evictor.abort();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let held_user = format!("quota-evict-held-survive-{}", std::process::id());
let held = quota_user_lock(&held_user);
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3));
for idx in 0..512u32 {
let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx);
let temp = quota_user_lock(&user);
drop(temp);
if idx % 32 == 0 {
tokio::task::yield_now().await;
}
}
let reacquired = quota_user_lock(&held_user);
assert!(
Arc::ptr_eq(&held, &reacquired),
"held user lock identity must remain stable across repeated evictions"
);
assert!(
map.get(&held_user).is_some(),
"held user entry must not be reclaimed while externally referenced"
);
drop(reacquired);
drop(held);
timeout(Duration::from_millis(300), async {
loop {
if map.get(&held_user).is_none() {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await
.expect("released held lock must be reclaimed by periodic evictor");
evictor.abort();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
let prefix = format!("quota-evict-saturated-{}", std::process::id());
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id());
let overflow_before = quota_user_lock(&overflow_user);
assert!(
map.get(&overflow_user).is_none(),
"saturated map must initially route new user to overflow stripe"
);
drop(retained);
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4));
timeout(Duration::from_millis(400), async {
loop {
if map.len() < QUOTA_USER_LOCKS_MAX {
break;
}
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await
.expect("periodic evictor must reclaim stale saturated entries");
let overflow_after = quota_user_lock(&overflow_user);
assert!(
map.get(&overflow_user).is_some(),
"after eviction, overflow user should become cacheable again"
);
assert!(
Arc::strong_count(&overflow_after) >= 2,
"cacheable lock should be held by map and caller"
);
drop(overflow_before);
drop(overflow_after);
evictor.abort();
}

View File

@ -0,0 +1,135 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::Waker;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
// Context stores a reference; leak one Waker for deterministic test scope.
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
#[tokio::test]
async fn adversarial_map_churn_cannot_bypass_held_writer_lock() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-writer-user";
let held_lock = quota_user_lock(user);
let _held_guard = held_lock
.try_lock()
.expect("test must hold initial user lock before StatsIo poll");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
map.clear();
let churned_lock = quota_user_lock(user);
assert!(
!Arc::ptr_eq(&held_lock, &churned_lock),
"precondition: map churn should produce a distinct lock identity"
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]);
assert!(
matches!(poll, Poll::Pending),
"writer must remain pending on the originally-held lock identity"
);
}
#[tokio::test]
async fn adversarial_map_churn_cannot_bypass_held_reader_lock() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-reader-user";
let held_lock = quota_user_lock(user);
let _held_guard = held_lock
.try_lock()
.expect("test must hold initial user lock before StatsIo poll");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
map.clear();
let churned_lock = quota_user_lock(user);
assert!(
!Arc::ptr_eq(&held_lock, &churned_lock),
"precondition: map churn should produce a distinct lock identity"
);
let (_wake_counter, mut cx) = build_context();
let mut storage = [0u8; 8];
let mut read_buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf);
assert!(
matches!(poll, Poll::Pending),
"reader must remain pending on the originally-held lock identity"
);
}
#[tokio::test]
async fn business_no_lock_contention_keeps_writer_progress() {
let _guard = quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-identity-progress-user";
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]);
assert!(
matches!(poll, Poll::Ready(Ok(2))),
"writer should progress immediately without contention"
);
}

View File

@ -127,7 +127,7 @@ fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
}
#[test]
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
@ -142,6 +142,8 @@ fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
drop(retained);
quota_user_lock_evict();
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
let overflow = quota_user_lock(&overflow_user);

View File

@ -0,0 +1,249 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::AsyncWriteExt;
use tokio::time::{Duration, Instant, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
(wake_counter, Context::from_waker(leaked_waker))
}
fn sleep_slot_ptr(slot: &Option<Pin<Box<tokio::time::Sleep>>>) -> usize {
slot.as_ref()
.map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize)
.unwrap_or(0)
}
#[tokio::test]
async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() {
let _guard = quota_test_guard();
let user = format!("retry-alloc-single-pending-{}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold local lock to force retry scheduling");
reset_quota_retry_sleep_allocs_for_tests();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (_wake_counter, mut cx) = build_context();
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
assert!(first.is_pending());
let allocs_after_first = quota_retry_sleep_allocs_for_tests();
let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep);
let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
assert!(second.is_pending());
let allocs_after_second = quota_retry_sleep_allocs_for_tests();
let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep);
assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer");
assert_eq!(
allocs_after_second, 1,
"repoll while the same timer is pending must not allocate again"
);
assert_eq!(
ptr_after_first, ptr_after_second,
"repoll while pending should retain the same timer allocation"
);
drop(held_guard);
}
#[tokio::test]
async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() {
let _guard = quota_test_guard();
let user = format!("retry-alloc-per-cycle-{}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold local lock to keep write path pending");
reset_quota_retry_sleep_allocs_for_tests();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (wake_counter, mut cx) = build_context();
let mut polls = 0u64;
let mut observed_wakes = 0usize;
let started = Instant::now();
while started.elapsed() < Duration::from_millis(70) {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]);
polls = polls.saturating_add(1);
assert!(poll.is_pending());
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed_wakes {
observed_wakes = wakes;
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let allocs = quota_retry_sleep_allocs_for_tests();
assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers");
assert!(
allocs < polls,
"timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})"
);
drop(held_guard);
}
#[tokio::test]
async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() {
let _guard = quota_test_guard();
let user = format!("retry-latency-envelope-{}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold local lock for sustained contention");
reset_quota_retry_sleep_allocs_for_tests();
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
let (wake_counter, mut cx) = build_context();
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]);
assert!(first.is_pending());
let started = Instant::now();
let mut last_wakes = 0usize;
let mut wake_instants = Vec::new();
while started.elapsed() < Duration::from_millis(120) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > last_wakes {
last_wakes = wakes;
wake_instants.push(Instant::now());
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]);
assert!(pending.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let mut max_gap = Duration::from_millis(0);
for idx in 1..wake_instants.len() {
let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]);
if gap > max_gap {
max_gap = gap;
}
}
assert!(
max_gap <= Duration::from_millis(35),
"retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}"
);
assert!(
quota_retry_sleep_allocs_for_tests() <= 16,
"allocation cycles must remain bounded during a short contention window"
);
drop(held_guard);
}
#[tokio::test]
async fn micro_benchmark_release_to_completion_latency_stays_bounded() {
let _guard = quota_test_guard();
let rounds = 96usize;
let mut samples_ms = Vec::with_capacity(rounds);
for round in 0..rounds {
let user = format!("retry-release-latency-{}-{round}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold local lock before spawning blocked writer");
let writer = tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::new(Stats::new()),
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
Instant::now(),
);
io.write_all(&[0xD1]).await
});
tokio::time::sleep(Duration::from_millis(2)).await;
let release_at = Instant::now();
drop(held_guard);
let done = timeout(Duration::from_millis(120), writer)
.await
.expect("blocked writer must complete after release")
.expect("writer task must not panic");
assert!(done.is_ok());
samples_ms.push(release_at.elapsed().as_millis() as u64);
}
samples_ms.sort_unstable();
let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1));
let p95_ms = samples_ms[p95_idx];
assert!(
p95_ms <= 40,
"contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
);
}

View File

@ -0,0 +1,241 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::ReadBuf;
use tokio::time::{Duration, Instant};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}")));
}
retained
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_contention_wake_rate_decays_with_backoff_curve() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-bench-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before benchmark run");
let waiters = 64usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(io).poll_write(&mut cx, &[0x71]);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let mut observed = vec![0usize; waiters];
let start = Instant::now();
let mut wakes_at_40ms = 0usize;
let mut wakes_at_160ms = 0usize;
while start.elapsed() < Duration::from_millis(200) {
for (idx, counter) in wake_counters.iter().enumerate() {
let wakes = counter.wakes.load(Ordering::Relaxed);
if wakes > observed[idx] {
observed[idx] = wakes;
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]);
assert!(pending.is_pending());
}
}
let elapsed = start.elapsed();
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
wakes_at_40ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
wakes_at_160ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
let wakes_at_200ms = total_wakes;
let early_window_wakes = wakes_at_40ms;
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
assert!(
total_wakes <= waiters * 28,
"backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
);
assert!(
early_window_wakes > 0,
"benchmark failed to observe early contention wakes"
);
assert!(
late_window_wakes * 4 <= early_window_wakes * 3,
"wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
);
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_read_contention_wake_rate_decays_with_backoff_curve() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-read-bench-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before read benchmark run");
let waiters = 64usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let pending = Pin::new(io).poll_read(&mut cx, &mut buf);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let mut observed = vec![0usize; waiters];
let start = Instant::now();
let mut wakes_at_40ms = 0usize;
let mut wakes_at_160ms = 0usize;
while start.elapsed() < Duration::from_millis(200) {
for (idx, counter) in wake_counters.iter().enumerate() {
let wakes = counter.wakes.load(Ordering::Relaxed);
if wakes > observed[idx] {
observed[idx] = wakes;
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf);
assert!(pending.is_pending());
}
}
let elapsed = start.elapsed();
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
wakes_at_40ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
wakes_at_160ms = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
let wakes_at_200ms = total_wakes;
let early_window_wakes = wakes_at_40ms;
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
assert!(
total_wakes <= waiters * 28,
"read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
);
assert!(
early_window_wakes > 0,
"read benchmark failed to observe early contention wakes"
);
assert!(
late_window_wakes * 4 <= early_window_wakes * 3,
"read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
);
drop(held_guard);
}

View File

@ -0,0 +1,339 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::ReadBuf;
use tokio::time::{Duration, Instant};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}")));
}
retained
}
#[tokio::test]
async fn positive_uncontended_writer_keeps_retry_wakes_zero() {
let _guard = quota_test_guard();
let stats = Arc::new(Stats::new());
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
"quota-backoff-positive".to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]);
assert!(poll.is_ready(), "uncontended writer must complete immediately");
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
0,
"uncontended path must not schedule deferred contention wakes"
);
}
#[tokio::test]
async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-adversarial-writer";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
assert!(first.is_pending());
let start = Instant::now();
let mut observed = 0usize;
while start.elapsed() < Duration::from_millis(80) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed {
observed = wakes;
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
assert!(pending.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
"sustained contention must be rate limited; observed wakes={} in 80ms",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]);
assert!(ready.is_ready());
}
#[tokio::test]
async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-adversarial-reader";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling reader");
let mut io = StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(first.is_pending());
let start = Instant::now();
let mut observed = 0usize;
while start.elapsed() < Duration::from_millis(80) {
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > observed {
observed = wakes;
let mut next = ReadBuf::new(&mut storage);
let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next);
assert!(pending.is_pending());
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
"sustained contention must be rate limited; observed wakes={} in 80ms",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
let mut done = ReadBuf::new(&mut storage);
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done);
assert!(ready.is_ready());
}
#[tokio::test]
async fn edge_backoff_attempt_resets_after_contention_release() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-edge-reset";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before polling writer");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]);
assert!(initial.is_pending());
tokio::time::sleep(Duration::from_millis(15)).await;
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
if wakes > 0 {
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]);
assert!(pending.is_pending());
}
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
assert!(ready.is_ready());
assert!(
!io.quota_write_wake_scheduled,
"successful write must clear deferred wake scheduling flag"
);
assert!(
io.quota_write_retry_sleep.is_none(),
"successful write must clear deferred sleep slot"
);
}
#[tokio::test]
async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-backoff-fuzz-writer";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before fuzz loop");
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.to_string(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut seed = 0x5EED_CAFE_7788_9900u64;
for _ in 0..64 {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]);
assert!(poll.is_pending());
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let sleep_ms = (seed % 4) as u64;
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 24,
"fuzzed repoll schedule must keep wake budget bounded; observed wakes={}",
wake_counter.wakes.load(Ordering::Relaxed)
);
drop(held_guard);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = format!("quota-backoff-stress-{}", std::process::id());
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold quota lock before launching stress waiters");
let waiters = 48usize;
let mut ios = Vec::with_capacity(waiters);
let mut wake_counters = Vec::with_capacity(waiters);
for _ in 0..waiters {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(4096),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(io).poll_write(&mut cx, &[0x61]);
assert!(pending.is_pending());
wake_counters.push(counter);
}
let start = Instant::now();
while start.elapsed() < Duration::from_millis(120) {
for (idx, counter) in wake_counters.iter().enumerate() {
if counter.wakes.load(Ordering::Relaxed) > 0 {
let waker = Waker::from(Arc::clone(counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]);
assert!(pending.is_pending());
}
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= waiters * 20,
"stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}"
);
drop(held_guard);
}

View File

@ -0,0 +1,246 @@
use super::*;
use crate::stats::Stats;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_uncontended_quota_limited_writer_completes() {
let _guard = quota_test_guard();
let stats = Arc::new(Stats::new());
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
"tdd-uncontended".to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let result = io.write_all(&[0x41, 0x42, 0x43]).await;
assert!(result.is_ok(), "uncontended writer must complete");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() {
let _guard = quota_test_guard();
let user = format!("tdd-writer-storm-{}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock before polling writers");
let stats = Arc::new(Stats::new());
let writers = 24usize;
let mut ios = Vec::with_capacity(writers);
let mut wake_counters = Vec::with_capacity(writers);
for _ in 0..writers {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]);
assert!(poll.is_pending(), "writer must be pending under held lock");
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= writers * 4,
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() {
let _guard = quota_test_guard();
let user = format!("tdd-reader-storm-{}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock before polling readers");
let stats = Arc::new(Stats::new());
let readers = 24usize;
let mut ios = Vec::with_capacity(readers);
let mut wake_counters = Vec::with_capacity(readers);
for _ in 0..readers {
ios.push(StatsIo::new(
tokio::io::empty(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending(), "reader must be pending under held lock");
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(25)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= readers * 4,
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_contended_waiters_resume_after_lock_release() {
let _guard = quota_test_guard();
let user = format!("tdd-resume-{}", std::process::id());
let held = quota_user_lock(&user);
let held_guard = held
.try_lock()
.expect("test must hold quota lock before launching waiters");
let stats = Arc::new(Stats::new());
let mut waiters = Vec::new();
for _ in 0..12 {
let stats = Arc::clone(&stats);
let user = user.clone();
waiters.push(tokio::spawn(async move {
let mut io = StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
stats,
user,
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
io.write_all(&[0x5A]).await
}));
}
tokio::time::sleep(Duration::from_millis(5)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let result = waiter.await.expect("waiter task must not panic");
assert!(result.is_ok(), "waiter must complete after release");
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() {
let _guard = quota_test_guard();
let mut seed = 0x9E37_79B9_AA55_1234u64;
for round in 0..20u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let writers = 8 + (seed as usize % 12);
let sleep_ms = 10 + (seed as u64 % 15);
let user = format!("tdd-fuzz-{}-{round}", std::process::id());
let held = quota_user_lock(&user);
let _held_guard = held
.try_lock()
.expect("test must hold quota lock in fuzz round");
let stats = Arc::new(Stats::new());
let mut ios = Vec::with_capacity(writers);
let mut wake_counters = Vec::with_capacity(writers);
for _ in 0..writers {
ios.push(StatsIo::new(
tokio::io::sink(),
Arc::new(SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(2048),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
));
}
for io in &mut ios {
let counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&counter));
let mut cx = Context::from_waker(&waker);
let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]);
assert!(matches!(poll, Poll::Pending));
wake_counters.push(counter);
}
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
let total_wakes: usize = wake_counters
.iter()
.map(|counter| counter.wakes.load(Ordering::Relaxed))
.sum();
assert!(
total_wakes <= writers * 4,
"fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}"
);
}
}

View File

@ -137,10 +137,10 @@ async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_
for _ in 0..8 {
tokio::task::yield_now().await;
}
assert_eq!(
wake_counter.wakes.load(Ordering::Relaxed),
wakes_after_first_yield,
"writer contention should not schedule unbounded wake storms before lock acquisition"
let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes_after_second_window <= wakes_after_first_yield.saturating_add(2),
"writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}"
);
drop(held_lock);

View File

@ -1884,6 +1884,32 @@ impl Stats {
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn sub_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
let Some(stats) = self.user_stats.get(user) else {
return;
};
Self::touch_user_stats(stats.value());
let counter = &stats.octets_to_client;
let mut current = counter.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(bytes);
match counter.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
}
pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
@ -2440,3 +2466,7 @@ mod connection_lease_security_tests;
#[cfg(test)]
#[path = "tests/replay_checker_security_tests.rs"]
mod replay_checker_security_tests;
#[cfg(test)]
#[path = "tests/user_octets_sub_security_tests.rs"]
mod user_octets_sub_security_tests;

View File

@ -0,0 +1,151 @@
use super::*;
use std::sync::Arc;
use std::thread;
#[test]
fn sub_user_octets_to_underflow_saturates_at_zero() {
let stats = Stats::new();
let user = "sub-underflow-user";
stats.add_user_octets_to(user, 3);
stats.sub_user_octets_to(user, 100);
assert_eq!(stats.get_user_total_octets(user), 0);
}
#[test]
fn sub_user_octets_to_does_not_affect_octets_from_client() {
let stats = Stats::new();
let user = "sub-isolation-user";
stats.add_user_octets_from(user, 17);
stats.add_user_octets_to(user, 5);
stats.sub_user_octets_to(user, 3);
assert_eq!(stats.get_user_total_octets(user), 19);
}
#[test]
fn light_fuzz_add_sub_model_matches_saturating_reference() {
let stats = Stats::new();
let user = "sub-fuzz-user";
let mut seed = 0x91D2_4CB8_EE77_1101u64;
let mut model_to = 0u64;
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats.add_user_octets_to(user, amt);
model_to = model_to.saturating_add(amt);
} else {
stats.sub_user_octets_to(user, amt);
model_to = model_to.saturating_sub(amt);
}
}
assert_eq!(stats.get_user_total_octets(user), model_to);
}
#[test]
fn stress_parallel_add_sub_never_underflows_or_panics() {
let stats = Arc::new(Stats::new());
let user = "sub-stress-user";
// Pre-fund with a large offset so subtractions never saturate at zero.
// This guarantees commutative updates, making the final state deterministic.
let base_offset = 10_000_000u64;
stats.add_user_octets_to(user, base_offset);
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let mut seed = 0xD00D_1000_0000_0000u64 ^ tid;
let mut net_delta = 0i64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x1f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(user, amt);
net_delta += amt as i64;
} else {
stats_for_thread.sub_user_octets_to(user, amt);
net_delta -= amt as i64;
}
}
net_delta
}));
}
let mut expected_net_delta = 0i64;
for worker in workers {
expected_net_delta += worker
.join()
.expect("sub-user stress worker must not panic");
}
let expected_total = (base_offset as i64 + expected_net_delta) as u64;
let total = stats.get_user_total_octets(user);
assert_eq!(
total, expected_total,
"concurrent add/sub lost updates or suffered ABA races"
);
}
#[test]
fn sub_user_octets_to_missing_user_is_noop() {
let stats = Stats::new();
stats.sub_user_octets_to("missing-user", 1024);
assert_eq!(stats.get_user_total_octets("missing-user"), 0);
}
#[test]
fn stress_parallel_per_user_models_remain_exact() {
let stats = Arc::new(Stats::new());
let mut workers = Vec::new();
for tid in 0..16u64 {
let stats_for_thread = Arc::clone(&stats);
workers.push(thread::spawn(move || {
let user = format!("sub-per-user-{tid}");
let mut seed = 0xFACE_0000_0000_0000u64 ^ tid;
let mut model = 0u64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let amt = ((seed >> 8) & 0x3f) + 1;
if (seed & 1) == 0 {
stats_for_thread.add_user_octets_to(&user, amt);
model = model.saturating_add(amt);
} else {
stats_for_thread.sub_user_octets_to(&user, amt);
model = model.saturating_sub(amt);
}
}
(user, model)
}));
}
for worker in workers {
let (user, model) = worker
.join()
.expect("per-user subtract stress worker must not panic");
assert_eq!(
stats.get_user_total_octets(&user),
model,
"per-user parallel model diverged"
);
}
}