diff --git a/.cargo/deny.toml b/.cargo/deny.toml new file mode 100644 index 0000000..09a5dd9 --- /dev/null +++ b/.cargo/deny.toml @@ -0,0 +1,15 @@ +[bans] +multiple-versions = "deny" +wildcards = "allow" +highlight = "all" + +# Explicitly flag the weak cryptography so the agent is forced to justify its existence +[[bans.skip]] +name = "md-5" +version = "*" +reason = "MUST VERIFY: Only allowed for legacy checksums, never for security." + +[[bans.skip]] +name = "sha1" +version = "*" +reason = "MUST VERIFY: Only allowed for backwards compatibility." diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index effe3ea..799f2ce 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -45,6 +45,18 @@ jobs: - name: Run tests run: cargo test --verbose + - name: Stress quota-lock suites (PR only) + if: github.event_name == 'pull_request' + env: + RUST_TEST_THREADS: 16 + run: | + set -euo pipefail + for i in $(seq 1 12); do + echo "[quota-lock-stress] iteration ${i}/12" + cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 + cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 + done + # clippy dont fail on warnings because of active development of telemt # and many warnings - name: Run clippy diff --git a/.gitignore b/.gitignore index 3a45e41..bc782ca 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ target #.idea/ proxy-secret +coverage-html/ \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index e6c5f2e..c17cc76 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working --- +### Context: The Telemt Project + +You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship. 
+ +**Adversarial Threat Model:** +The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic. + +**Core Architectural Pillars:** +1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake. +2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms. +3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake. +4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane. + +Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput. + + ### 0. 
Priority Resolution — Scope Control This section resolves conflicts between code quality enforcement and scope limitation. @@ -374,6 +390,12 @@ you MUST explain why existing invariants remain valid. - Do not modify existing tests unless the task explicitly requires it. - Do not weaken assertions. - Preserve determinism in testable components. +- Bug-first forces the discipline of proving you understand a bug before you fix it. Tests written after a fix almost always pass trivially and catch nothing new. +- Invariants over scenarios is the core shift. The route_mode table alone would have caught both BUG-1 and BUG-2 before they were written — "snapshot equals watch state after any transition burst" is a two-line property test that fails immediately on the current diverged-atomics code. +- Differential/model catches logic drift over time. +- Scheduler pressure is specifically aimed at the concurrent state bugs that keep reappearing. A single-threaded happy-path test of set_mode will never find subtle bugs; 10,000 concurrent calls will find it on the first run. +- Mutation gate answers your original complaint directly. It measures test power. If you can remove a bounds check and nothing breaks, the suite isn't covering that branch yet — it just says so explicitly. +- Dead parameter is a code smell rule. ### 15. 
Security Constraints diff --git a/Cargo.lock b/Cargo.lock index 787e357..fd8b44a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,7 +20,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -46,6 +46,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -102,9 +111,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "asn1-rs" -version = "0.5.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f6fd5ddaf0351dff5b8da21b2fb4ff8e08ddd02857f0bf69c47639106c0fff0" +checksum = "56624a96882bb8c26d61312ae18cb45868e5a9992ea73c58e45c3101e56a1e60" dependencies = [ "asn1-rs-derive", "asn1-rs-impl", @@ -112,31 +121,31 @@ dependencies = [ "nom", "num-traits", "rusticata-macros", - "thiserror 1.0.69", + "thiserror 2.0.18", "time", ] [[package]] name = "asn1-rs-derive" -version = "0.4.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "726535892e8eae7e70657b4c8ea93d26b8553afb1ce617caee529ef96d7dee6c" +checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", - "synstructure 0.12.6", + "syn", + "synstructure", ] [[package]] name = "asn1-rs-impl" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] @@ 
-147,7 +156,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -162,6 +171,28 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "aws-lc-rs" +version = "1.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "base64" version = "0.22.1" @@ -212,7 +243,7 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -273,21 +304,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -302,7 +335,18 @@ checksum = 
"c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", +] + +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.0", ] [[package]] @@ -312,7 +356,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -395,6 +439,25 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -407,6 +470,16 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -422,6 +495,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc32c" version = "0.6.8" @@ -442,25 +524,24 @@ dependencies = [ [[package]] name = "criterion" -version = "0.5.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ + "alloca", "anes", "cast", "ciborium", "clap", "criterion-plot", - "is-terminal", "itertools", "num-traits", - "once_cell", "oorandom", + "page_size", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "walkdir", @@ -468,9 +549,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", "itertools", @@ -552,12 +633,39 @@ dependencies = [ ] [[package]] -name = "dashmap" -version = "5.5.3" +name = "curve25519-dalek" +version = "4.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ "cfg-if", + "cpufeatures 0.2.17", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = 
"dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", "hashbrown 0.14.5", "lock_api", "once_cell", @@ -582,9 +690,9 @@ dependencies = [ [[package]] name = "der-parser" -version = "8.2.0" +version = "10.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbd676fbbab537128ef0278adb5576cf363cff6aa22a7b24effe97347cfab61e" +checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" dependencies = [ "asn1-rs", "displaydoc", @@ -622,9 +730,15 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "dynosaur" version = "0.3.0" @@ -642,7 +756,7 @@ checksum = "0b0713d5c1d52e774c5cd7bb8b043d7c0fc4f921abfb678556140bfbe6ab2364" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -670,7 +784,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -696,15 +810,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "filetime" -version = "0.2.27" +name = "fiat-crypto" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" -dependencies = [ - "cfg-if", - "libc", - "libredox", -] +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" [[package]] name = "find-msvc-tools" @@ -739,6 +848,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = 
"fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "fsevent-sys" version = "4.1.0" @@ -804,7 +919,7 @@ checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -882,6 +997,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", + "rand_core 0.10.0", "wasip2", "wasip3", ] @@ -958,12 +1074,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - [[package]] name = "hex" version = "0.4.3" @@ -986,7 +1096,7 @@ dependencies = [ "idna", "ipnet", "once_cell", - "rand", + "rand 0.9.2", "ring", "thiserror 2.0.18", "tinyvec", @@ -1008,7 +1118,7 @@ dependencies = [ "moka", "once_cell", "parking_lot", - "rand", + "rand 0.9.2", "resolv-conf", "smallvec", "thiserror 2.0.18", @@ -1116,7 +1226,6 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots 1.0.6", ] [[package]] @@ -1286,17 +1395,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "inotify" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff" -dependencies = [ - "bitflags 1.3.2", - "inotify-sys", - "libc", -] - [[package]] name = "inotify" version = "0.11.1" @@ -1347,9 +1445,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "ipnetwork" -version = "0.20.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" dependencies = [ "serde", ] @@ -1364,22 +1462,11 @@ dependencies = [ "serde", ] -[[package]] -name = "is-terminal" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.61.2", -] - [[package]] name = "itertools" -version = "0.10.5" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -1390,6 +1477,38 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.91" @@ -1438,18 +1557,6 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" -[[package]] -name = "libredox" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" -dependencies = [ - "bitflags 2.11.0", - "libc", - "plain", - "redox_syscall 0.7.3", -] - [[package]] name = "linux-raw-sys" version = "0.12.1" @@ -1538,18 +1645,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "log", - "wasi", - "windows-sys 0.48.0", -] - [[package]] name = "mio" version = "1.1.1" @@ -1581,13 +1676,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.28.0" +version = "0.31.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" dependencies = [ "bitflags 2.11.0", "cfg-if", - "cfg_aliases 0.1.1", + "cfg_aliases", "libc", "memoffset", ] @@ -1602,25 +1697,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "notify" -version = "6.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d" -dependencies = [ - "bitflags 2.11.0", - "crossbeam-channel", - "filetime", - "fsevent-sys", - "inotify 0.9.6", - "kqueue", - "libc", - "log", - "mio 0.8.11", - "walkdir", - "windows-sys 0.48.0", -] - [[package]] name = "notify" version = "8.2.0" @@ -1629,11 +1705,11 @@ checksum = "4d3d07927151ff8575b7087f245456e549fea62edf0ec4e565a5ee50c8402bc3" dependencies = [ "bitflags 2.11.0", "fsevent-sys", - 
"inotify 0.11.1", + "inotify", "kqueue", "libc", "log", - "mio 1.1.1", + "mio", "notify-types", "walkdir", "windows-sys 0.60.2", @@ -1693,9 +1769,9 @@ dependencies = [ [[package]] name = "oid-registry" -version = "0.6.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bedf36ffb6ba96c2eb7144ef6270557b52e54b20c0a8e1eb2ff99a6c6959bff" +checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" dependencies = [ "asn1-rs", ] @@ -1722,6 +1798,22 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking_lot" version = "0.12.5" @@ -1740,7 +1832,7 @@ checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.18", + "redox_syscall", "smallvec", "windows-link", ] @@ -1768,7 +1860,7 @@ checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -1793,12 +1885,6 @@ dependencies = [ "spki", ] -[[package]] -name = "plain" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" - [[package]] name = "plotters" version = "0.3.7" @@ -1833,7 +1919,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", "opaque-debug", "universal-hash", ] @@ -1845,7 +1931,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "opaque-debug", "universal-hash", ] @@ -1887,7 +1973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.117", + "syn", ] [[package]] @@ -1909,7 +1995,7 @@ dependencies = [ "bit-vec", "bitflags 2.11.0", "num-traits", - "rand", + "rand 0.9.2", "rand_chacha", "rand_xorshift", "regex-syntax", @@ -1931,7 +2017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", - "cfg_aliases 0.2.1", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -1950,10 +2036,11 @@ version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ + "aws-lc-rs", "bytes", "getrandom 0.3.4", "lru-slab", - "rand", + "rand 0.9.2", "ring", "rustc-hash", "rustls", @@ -1971,7 +2058,7 @@ version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ - "cfg_aliases 0.2.1", + "cfg_aliases", "libc", "once_cell", "socket2 0.6.3", @@ -2010,6 +2097,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +dependencies = [ + "chacha20 0.10.0", + 
"getrandom 0.4.2", + "rand_core 0.10.0", +] + [[package]] name = "rand_chacha" version = "0.9.0" @@ -2038,6 +2136,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" + [[package]] name = "rand_xorshift" version = "0.4.0" @@ -2076,15 +2180,6 @@ dependencies = [ "bitflags 2.11.0", ] -[[package]] -name = "redox_syscall" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" -dependencies = [ - "bitflags 2.11.0", -] - [[package]] name = "regex" version = "1.12.3" @@ -2116,9 +2211,9 @@ checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "reqwest" -version = "0.12.28" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" dependencies = [ "base64", "bytes", @@ -2136,9 +2231,7 @@ dependencies = [ "quinn", "rustls", "rustls-pki-types", - "serde", - "serde_json", - "serde_urlencoded", + "rustls-platform-verifier", "sync_wrapper", "tokio", "tokio-rustls", @@ -2149,7 +2242,6 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 1.0.6", ] [[package]] @@ -2228,6 +2320,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "aws-lc-rs", "once_cell", "ring", "rustls-pki-types", @@ -2236,6 +2329,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -2247,11 +2352,39 @@ dependencies = [ ] [[package]] -name = "rustls-webpki" -version = "0.103.9" +name = "rustls-platform-verifier" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ + "core-foundation", + "core-foundation-sys", + "jni", + "log", + "once_cell", + "rustls", + "rustls-native-certs", + "rustls-platform-verifier-android", + "rustls-webpki", + "security-framework", + "security-framework-sys", + "webpki-root-certs", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-platform-verifier-android" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" + +[[package]] +name = "rustls-webpki" +version = "0.103.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -2290,6 +2423,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2304,7 +2446,30 @@ checksum = "22f968c5ea23d555e670b449c1c5e7b2fc399fdaec1d304a17cd48e288abc107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", +] + +[[package]] +name = "security-framework" +version = "3.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags 2.11.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", ] [[package]] @@ -2350,7 +2515,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2368,11 +2533,11 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.9" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" dependencies = [ - "serde", + "serde_core", ] [[package]] @@ -2394,7 +2559,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -2405,7 +2570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -2428,10 +2593,10 @@ dependencies = [ "libc", "log", "lru_time_cache", - "notify 8.2.0", + "notify", "percent-encoding", "pin-project", - "rand", + "rand 0.9.2", "sealed", "sendfd", "serde", @@ -2462,7 +2627,7 @@ dependencies = [ "chacha20poly1305", "hkdf", "md-5", - "rand", + "rand 0.9.2", "ring-compat", "sha1", ] @@ -2555,23 +2720,18 @@ version = "1.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "subtle" version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.117" @@ -2592,18 +2752,6 @@ dependencies = [ "futures-core", ] -[[package]] -name = "synstructure" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "unicode-xid", -] - [[package]] name = "synstructure" version = "0.13.2" @@ -2612,7 +2760,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2623,7 +2771,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.3.28" +version = "3.4.0" dependencies = [ "aes", "anyhow", @@ -2650,12 +2798,12 @@ dependencies = [ "lru", "md-5", "nix", - "notify 6.1.1", + "notify", "num-bigint", "num-traits", "parking_lot", "proptest", - "rand", + "rand 0.10.0", "regex", "reqwest", "rustls", @@ -2664,7 +2812,9 @@ dependencies = [ "sha1", "sha2", "shadowsocks", - "socket2 0.5.10", + "socket2 0.6.3", + "static_assertions", + "subtle", "thiserror 2.0.18", "tokio", 
"tokio-rustls", @@ -2674,7 +2824,8 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webpki-roots 0.26.11", + "webpki-roots", + "x25519-dalek", "x509-parser", "zeroize", ] @@ -2718,7 +2869,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2729,7 +2880,7 @@ checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2815,7 +2966,7 @@ checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" dependencies = [ "bytes", "libc", - "mio 1.1.1", + "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -2833,7 +2984,7 @@ checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -2904,44 +3055,42 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.23" +version = "1.0.7+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +checksum = "dd28d57d8a6f6e458bc0b8784f8fdcc4b99a437936056fa122cb234f18656a96" dependencies = [ "indexmap", - "serde", + "serde_core", "serde_spanned", "toml_datetime", - "toml_write", + "toml_parser", + "toml_writer", "winnow", ] [[package]] -name = "toml_write" 
-version = "0.1.2" +name = "toml_datetime" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +checksum = "9b320e741db58cac564e26c607d3cc1fdc4a88fd36c879568c07856ed83ff3e9" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.0.10+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df25b4befd31c4816df190124375d5a20c6b6921e2cad937316de3fccd63420" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.7+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17aaa1c6e3dc22b1da4b6bba97d066e354c7945cac2f7852d4e4e7ca7a6b56d" [[package]] name = "tower" @@ -3007,7 +3156,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3057,7 +3206,7 @@ checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3245,7 +3394,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wasm-bindgen-shared", ] @@ -3313,12 +3462,12 @@ dependencies = [ ] [[package]] -name = "webpki-roots" -version = "0.26.11" +name = "webpki-root-certs" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" dependencies = [ - "webpki-roots 1.0.6", + "rustls-pki-types", ] [[package]] @@ -3336,6 +3485,22 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" +[[package]] +name = "winapi" +version = "0.3.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -3345,6 +3510,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" @@ -3366,7 +3537,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3377,7 +3548,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3404,6 +3575,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -3440,6 +3620,21 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + 
"windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3488,6 +3683,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3506,6 +3707,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3524,6 +3731,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3554,6 +3767,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3572,6 +3791,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3590,6 +3815,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3608,6 +3839,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -3628,12 +3865,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.15" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" -dependencies = [ - "memchr", -] +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" [[package]] name = "winreg" @@ -3675,7 +3909,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn 2.0.117", + "syn", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -3691,7 +3925,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.117", + "syn", "wit-bindgen-core", 
"wit-bindgen-rust", ] @@ -3740,10 +3974,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" [[package]] -name = "x509-parser" -version = "0.15.1" +name = "x25519-dalek" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7069fba5b66b9193bd2c5d3d4ff12b839118f6bcbef5328efafafb5395cf63da" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" dependencies = [ "asn1-rs", "data-encoding", @@ -3752,7 +3998,7 @@ dependencies = [ "nom", "oid-registry", "rusticata-macros", - "thiserror 1.0.69", + "thiserror 2.0.18", "time", ] @@ -3775,8 +4021,8 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", - "synstructure 0.13.2", + "syn", + "synstructure", ] [[package]] @@ -3796,7 +4042,7 @@ checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3816,8 +4062,8 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", - "synstructure 0.13.2", + "syn", + "synstructure", ] [[package]] @@ -3837,7 +4083,7 @@ checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] @@ -3870,7 +4116,7 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.117", + "syn", ] [[package]] diff --git a/Cargo.toml 
b/Cargo.toml index 97855f3..b4ea034 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.3.28" +version = "3.4.0" edition = "2024" [dependencies] @@ -22,17 +22,19 @@ hmac = "0.12" crc32fast = "1.4" crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } +subtle = "2.6" +static_assertions = "1.1" # Network -socket2 = { version = "0.5", features = ["all"] } -nix = { version = "0.28", default-features = false, features = ["net"] } +socket2 = { version = "0.6", features = ["all"] } +nix = { version = "0.31", default-features = false, features = ["net"] } shadowsocks = { version = "1.24", features = ["aead-cipher-2022"] } # Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -toml = "0.8" -x509-parser = "0.15" +toml = "1.0" +x509-parser = "0.18" # Utils bytes = "1.9" @@ -40,10 +42,10 @@ thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } parking_lot = "0.12" -dashmap = "5.5" +dashmap = "6.1" arc-swap = "1.7" lru = "0.16" -rand = "0.9" +rand = "0.10" chrono = { version = "0.4", features = ["serde"] } hex = "0.4" base64 = "0.22" @@ -52,23 +54,24 @@ regex = "1.11" crossbeam-queue = "0.3" num-bigint = "0.4" num-traits = "0.2" +x25519-dalek = "2" anyhow = "1.0" # HTTP -reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } -notify = { version = "6", features = ["macos_fsevent"] } -ipnetwork = "0.20" +reqwest = { version = "0.13", features = ["rustls"], default-features = false } +notify = "8.2" +ipnetwork = { version = "0.21", features = ["serde"] } hyper = { version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http-body-util = "0.1" httpdate = "1.0" tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] } -webpki-roots = "0.26" 
+webpki-roots = "1.0" [dev-dependencies] tokio-test = "0.4" -criterion = "0.5" +criterion = "0.8" proptest = "1.4" futures = "0.3" diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 90da08a..738550c 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -260,6 +260,65 @@ This document lists all configuration keys accepted by `config.toml`. | tls_full_cert_ttl_secs | `u64` | `90` | — | TTL for sending full cert payload per (domain, client IP) tuple. | | alpn_enforce | `bool` | `true` | — | Enforces ALPN echo behavior based on client preference. | | mask_proxy_protocol | `u8` | `0` | — | PROXY protocol mode for mask backend (`0` disabled, `1` v1, `2` v2). | +| mask_shape_hardening | `bool` | `true` | — | Enables client->mask shape-channel hardening by applying controlled tail padding to bucket boundaries on mask relay shutdown. | +| mask_shape_bucket_floor_bytes | `usize` | `512` | Must be `> 0`; should be `<= mask_shape_bucket_cap_bytes`. | Minimum bucket size used by shape-channel hardening. | +| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. | +| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. | +| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. | +| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. | +| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. 
| Lower bound (ms) for masking outcome normalization target. | +| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. | + +### Shape-channel hardening notes (`[censorship]`) + +These parameters are designed to reduce one specific fingerprint source during masking: the exact number of bytes sent from proxy to `mask_host` for invalid or probing traffic. + +Without hardening, a censor can often correlate probe input length with backend-observed length very precisely (for example: `5 + body_sent` on early TLS reject paths). That creates a length-based classifier signal. + +When `mask_shape_hardening = true`, Telemt pads the **client->mask** stream tail to a bucket boundary at relay shutdown: + +- Total bytes sent to mask are first measured. +- A bucket is selected using powers of two starting from `mask_shape_bucket_floor_bytes`. +- Padding is added only if total bytes are below `mask_shape_bucket_cap_bytes`. +- If bytes already exceed cap, no extra padding is added. + +This means multiple nearby probe sizes collapse into the same backend-observed size class, making active classification harder. + +Practical trade-offs: + +- Better anti-fingerprinting on size/shape channel. +- Slightly higher egress overhead for small probes due to padding. +- Behavior is intentionally conservative and enabled by default. + +Recommended starting profile: + +- `mask_shape_hardening = true` (default) +- `mask_shape_bucket_floor_bytes = 512` +- `mask_shape_bucket_cap_bytes = 4096` + +### Above-cap blur notes (`[censorship]`) + +`mask_shape_above_cap_blur` adds a second-stage blur for very large probes that are already above `mask_shape_bucket_cap_bytes`. + +- A random tail in `[0, mask_shape_above_cap_blur_max_bytes]` is appended. +- This reduces exact-size leakage above cap at bounded overhead. +- Keep `mask_shape_above_cap_blur_max_bytes` conservative to avoid unnecessary egress growth. 
+ +### Timing normalization envelope notes (`[censorship]`) + +`mask_timing_normalization_enabled` smooths timing differences between masking outcomes by applying a target duration envelope. + +- A random target is selected in `[mask_timing_normalization_floor_ms, mask_timing_normalization_ceiling_ms]`. +- Fast paths are delayed up to the selected target. +- Slow paths are not forced to finish by the ceiling (the envelope is best-effort shaping, not truncation). + +Recommended starting profile for timing shaping: + +- `mask_timing_normalization_enabled = true` +- `mask_timing_normalization_floor_ms = 180` +- `mask_timing_normalization_ceiling_ms = 320` + +If your backend or network is very bandwidth-constrained, reduce cap first. If probes are still too distinguishable in your environment, increase floor gradually. ## [access] diff --git a/docs/model/FakeTLS.png b/docs/model/FakeTLS.png new file mode 100644 index 0000000..5f6782e Binary files /dev/null and b/docs/model/FakeTLS.png differ diff --git a/docs/model/architecture.png b/docs/model/architecture.png new file mode 100644 index 0000000..71d4a17 Binary files /dev/null and b/docs/model/architecture.png differ diff --git a/src/api/model.rs b/src/api/model.rs index 6578d35..94b50f6 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -1,10 +1,12 @@ use std::net::IpAddr; +use std::sync::OnceLock; use chrono::{DateTime, Utc}; use hyper::StatusCode; -use rand::Rng; use serde::{Deserialize, Serialize}; +use crate::crypto::SecureRandom; + const MAX_USERNAME_LEN: usize = 64; #[derive(Debug)] @@ -196,8 +198,6 @@ pub(super) struct ZeroPoolData { pub(super) pool_swap_total: u64, pub(super) pool_drain_active: u64, pub(super) pool_force_close_total: u64, - pub(super) pool_drain_soft_evict_total: u64, - pub(super) pool_drain_soft_evict_writer_total: u64, pub(super) pool_stale_pick_total: u64, pub(super) writer_removed_total: u64, pub(super) writer_removed_unexpected_total: u64, @@ -206,16 +206,6 @@ pub(super) struct 
ZeroPoolData { pub(super) refill_failed_total: u64, pub(super) writer_restored_same_endpoint_total: u64, pub(super) writer_restored_fallback_total: u64, - pub(super) teardown_attempt_total_normal: u64, - pub(super) teardown_attempt_total_hard_detach: u64, - pub(super) teardown_success_total_normal: u64, - pub(super) teardown_success_total_hard_detach: u64, - pub(super) teardown_timeout_total: u64, - pub(super) teardown_escalation_total: u64, - pub(super) teardown_noop_total: u64, - pub(super) teardown_cleanup_side_effect_failures_total: u64, - pub(super) teardown_duration_count_total: u64, - pub(super) teardown_duration_sum_seconds_total: f64, } #[derive(Serialize, Clone)] @@ -248,7 +238,6 @@ pub(super) struct MeWritersSummary { pub(super) available_pct: f64, pub(super) required_writers: usize, pub(super) alive_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, pub(super) fresh_alive_writers: usize, pub(super) fresh_coverage_pct: f64, @@ -297,7 +286,6 @@ pub(super) struct DcStatus { pub(super) floor_max: usize, pub(super) floor_capped: bool, pub(super) alive_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, pub(super) fresh_alive_writers: usize, pub(super) fresh_coverage_pct: f64, @@ -375,12 +363,6 @@ pub(super) struct MinimalMeRuntimeData { pub(super) me_reconnect_backoff_cap_ms: u64, pub(super) me_reconnect_fast_retry_count: u32, pub(super) me_pool_drain_ttl_secs: u64, - pub(super) me_instadrain: bool, - pub(super) me_pool_drain_soft_evict_enabled: bool, - pub(super) me_pool_drain_soft_evict_grace_secs: u64, - pub(super) me_pool_drain_soft_evict_per_writer: u8, - pub(super) me_pool_drain_soft_evict_budget_per_core: u16, - pub(super) me_pool_drain_soft_evict_cooldown_ms: u64, pub(super) me_pool_force_close_secs: u64, pub(super) me_pool_min_fresh_ratio: f32, pub(super) me_bind_stale_mode: &'static str, @@ -502,7 +484,9 @@ pub(super) fn is_valid_username(user: &str) -> bool { } pub(super) fn 
random_user_secret() -> String { + static API_SECRET_RNG: OnceLock = OnceLock::new(); + let rng = API_SECRET_RNG.get_or_init(SecureRandom::new); let mut bytes = [0u8; 16]; - rand::rng().fill(&mut bytes); + rng.fill(&mut bytes); hex::encode(bytes) } diff --git a/src/api/runtime_min.rs b/src/api/runtime_min.rs index 047fd9c..986f138 100644 --- a/src/api/runtime_min.rs +++ b/src/api/runtime_min.rs @@ -4,9 +4,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use serde::Serialize; use crate::config::ProxyConfig; -use crate::stats::{ - MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, Stats, -}; use super::ApiShared; @@ -101,50 +98,6 @@ pub(super) struct RuntimeMeQualityCountersData { pub(super) reconnect_success_total: u64, } -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownAttemptData { - pub(super) reason: &'static str, - pub(super) mode: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownSuccessData { - pub(super) mode: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownSideEffectData { - pub(super) step: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownDurationBucketData { - pub(super) le_seconds: &'static str, - pub(super) total: u64, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownDurationData { - pub(super) mode: &'static str, - pub(super) count: u64, - pub(super) sum_seconds: f64, - pub(super) buckets: Vec, -} - -#[derive(Serialize)] -pub(super) struct RuntimeMeQualityTeardownData { - pub(super) attempts: Vec, - pub(super) success: Vec, - pub(super) timeout_total: u64, - pub(super) escalation_total: u64, - pub(super) noop_total: u64, - pub(super) cleanup_side_effect_failures: Vec, - pub(super) duration: Vec, -} - #[derive(Serialize)] pub(super) struct RuntimeMeQualityRouteDropData { pub(super) no_conn_total: u64, @@ -179,14 
+132,12 @@ pub(super) struct RuntimeMeQualityDcRttData { pub(super) rtt_ema_ms: Option, pub(super) alive_writers: usize, pub(super) required_writers: usize, - pub(super) coverage_ratio: f64, pub(super) coverage_pct: f64, } #[derive(Serialize)] pub(super) struct RuntimeMeQualityPayload { pub(super) counters: RuntimeMeQualityCountersData, - pub(super) teardown: RuntimeMeQualityTeardownData, pub(super) route_drops: RuntimeMeQualityRouteDropData, pub(super) family_states: Vec, pub(super) drain_gate: RuntimeMeQualityDrainGateData, @@ -457,7 +408,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime reconnect_attempt_total: shared.stats.get_me_reconnect_attempts(), reconnect_success_total: shared.stats.get_me_reconnect_success(), }, - teardown: build_runtime_me_teardown_data(shared), route_drops: RuntimeMeQualityRouteDropData { no_conn_total: shared.stats.get_me_route_drop_no_conn(), channel_closed_total: shared.stats.get_me_route_drop_channel_closed(), @@ -480,7 +430,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime rtt_ema_ms: dc.rtt_ms, alive_writers: dc.alive_writers, required_writers: dc.required_writers, - coverage_ratio: dc.coverage_ratio, coverage_pct: dc.coverage_pct, }) .collect(), @@ -488,81 +437,6 @@ pub(super) async fn build_runtime_me_quality_data(shared: &ApiShared) -> Runtime } } -fn build_runtime_me_teardown_data(shared: &ApiShared) -> RuntimeMeQualityTeardownData { - let attempts = MeWriterTeardownReason::ALL - .iter() - .copied() - .flat_map(|reason| { - MeWriterTeardownMode::ALL - .iter() - .copied() - .map(move |mode| RuntimeMeQualityTeardownAttemptData { - reason: reason.as_str(), - mode: mode.as_str(), - total: shared.stats.get_me_writer_teardown_attempt_total(reason, mode), - }) - }) - .collect(); - - let success = MeWriterTeardownMode::ALL - .iter() - .copied() - .map(|mode| RuntimeMeQualityTeardownSuccessData { - mode: mode.as_str(), - total: 
shared.stats.get_me_writer_teardown_success_total(mode), - }) - .collect(); - - let cleanup_side_effect_failures = MeWriterCleanupSideEffectStep::ALL - .iter() - .copied() - .map(|step| RuntimeMeQualityTeardownSideEffectData { - step: step.as_str(), - total: shared - .stats - .get_me_writer_cleanup_side_effect_failures_total(step), - }) - .collect(); - - let duration = MeWriterTeardownMode::ALL - .iter() - .copied() - .map(|mode| { - let count = shared.stats.get_me_writer_teardown_duration_count(mode); - let mut buckets: Vec = Stats::me_writer_teardown_duration_bucket_labels() - .iter() - .enumerate() - .map(|(bucket_idx, label)| RuntimeMeQualityTeardownDurationBucketData { - le_seconds: label, - total: shared - .stats - .get_me_writer_teardown_duration_bucket_total(mode, bucket_idx), - }) - .collect(); - buckets.push(RuntimeMeQualityTeardownDurationBucketData { - le_seconds: "+Inf", - total: count, - }); - RuntimeMeQualityTeardownDurationData { - mode: mode.as_str(), - count, - sum_seconds: shared.stats.get_me_writer_teardown_duration_sum_seconds(mode), - buckets, - } - }) - .collect(); - - RuntimeMeQualityTeardownData { - attempts, - success, - timeout_total: shared.stats.get_me_writer_teardown_timeout_total(), - escalation_total: shared.stats.get_me_writer_teardown_escalation_total(), - noop_total: shared.stats.get_me_writer_teardown_noop_total(), - cleanup_side_effect_failures, - duration, - } -} - pub(super) async fn build_runtime_upstream_quality_data( shared: &ApiShared, ) -> RuntimeUpstreamQualityData { diff --git a/src/api/runtime_stats.rs b/src/api/runtime_stats.rs index 999e2cf..b646567 100644 --- a/src/api/runtime_stats.rs +++ b/src/api/runtime_stats.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use crate::config::ApiConfig; -use crate::stats::{MeWriterTeardownMode, Stats}; +use crate::stats::Stats; use crate::transport::upstream::IpPreference; use crate::transport::UpstreamRouteKind; @@ -96,8 +96,6 @@ pub(super) fn 
build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer pool_swap_total: stats.get_pool_swap_total(), pool_drain_active: stats.get_pool_drain_active(), pool_force_close_total: stats.get_pool_force_close_total(), - pool_drain_soft_evict_total: stats.get_pool_drain_soft_evict_total(), - pool_drain_soft_evict_writer_total: stats.get_pool_drain_soft_evict_writer_total(), pool_stale_pick_total: stats.get_pool_stale_pick_total(), writer_removed_total: stats.get_me_writer_removed_total(), writer_removed_unexpected_total: stats.get_me_writer_removed_unexpected_total(), @@ -106,29 +104,6 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer refill_failed_total: stats.get_me_refill_failed_total(), writer_restored_same_endpoint_total: stats.get_me_writer_restored_same_endpoint_total(), writer_restored_fallback_total: stats.get_me_writer_restored_fallback_total(), - teardown_attempt_total_normal: stats - .get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::Normal), - teardown_attempt_total_hard_detach: stats - .get_me_writer_teardown_attempt_total_by_mode(MeWriterTeardownMode::HardDetach), - teardown_success_total_normal: stats - .get_me_writer_teardown_success_total(MeWriterTeardownMode::Normal), - teardown_success_total_hard_detach: stats - .get_me_writer_teardown_success_total(MeWriterTeardownMode::HardDetach), - teardown_timeout_total: stats.get_me_writer_teardown_timeout_total(), - teardown_escalation_total: stats.get_me_writer_teardown_escalation_total(), - teardown_noop_total: stats.get_me_writer_teardown_noop_total(), - teardown_cleanup_side_effect_failures_total: stats - .get_me_writer_cleanup_side_effect_failures_total_all(), - teardown_duration_count_total: stats - .get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal) - .saturating_add( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::HardDetach), - ), - teardown_duration_sum_seconds_total: stats - 
.get_me_writer_teardown_duration_sum_seconds(MeWriterTeardownMode::Normal) - + stats.get_me_writer_teardown_duration_sum_seconds( - MeWriterTeardownMode::HardDetach, - ), }, desync: ZeroDesyncData { secure_padding_invalid_total: stats.get_secure_padding_invalid(), @@ -340,7 +315,6 @@ async fn get_minimal_payload_cached( available_pct: status.available_pct, required_writers: status.required_writers, alive_writers: status.alive_writers, - coverage_ratio: status.coverage_ratio, coverage_pct: status.coverage_pct, fresh_alive_writers: status.fresh_alive_writers, fresh_coverage_pct: status.fresh_coverage_pct, @@ -398,7 +372,6 @@ async fn get_minimal_payload_cached( floor_max: entry.floor_max, floor_capped: entry.floor_capped, alive_writers: entry.alive_writers, - coverage_ratio: entry.coverage_ratio, coverage_pct: entry.coverage_pct, fresh_alive_writers: entry.fresh_alive_writers, fresh_coverage_pct: entry.fresh_coverage_pct, @@ -452,12 +425,6 @@ async fn get_minimal_payload_cached( me_reconnect_backoff_cap_ms: runtime.me_reconnect_backoff_cap_ms, me_reconnect_fast_retry_count: runtime.me_reconnect_fast_retry_count, me_pool_drain_ttl_secs: runtime.me_pool_drain_ttl_secs, - me_instadrain: runtime.me_instadrain, - me_pool_drain_soft_evict_enabled: runtime.me_pool_drain_soft_evict_enabled, - me_pool_drain_soft_evict_grace_secs: runtime.me_pool_drain_soft_evict_grace_secs, - me_pool_drain_soft_evict_per_writer: runtime.me_pool_drain_soft_evict_per_writer, - me_pool_drain_soft_evict_budget_per_core: runtime.me_pool_drain_soft_evict_budget_per_core, - me_pool_drain_soft_evict_cooldown_ms: runtime.me_pool_drain_soft_evict_cooldown_ms, me_pool_force_close_secs: runtime.me_pool_force_close_secs, me_pool_min_fresh_ratio: runtime.me_pool_min_fresh_ratio, me_bind_stale_mode: runtime.me_bind_stale_mode, @@ -526,7 +493,6 @@ fn disabled_me_writers(now_epoch_secs: u64, reason: &'static str) -> MeWritersDa available_pct: 0.0, required_writers: 0, alive_writers: 0, - coverage_ratio: 0.0, 
coverage_pct: 0.0, fresh_alive_writers: 0, fresh_coverage_pct: 0.0, diff --git a/src/cli.rs b/src/cli.rs index 87dcfb5..035fe92 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -3,7 +3,7 @@ use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; -use rand::Rng; +use rand::RngExt; /// Options for the init command pub struct InitOptions { @@ -246,7 +246,7 @@ tls_full_cert_ttl_secs = 90 [access] replay_check_len = 65536 -replay_window_secs = 1800 +replay_window_secs = 120 ignore_time_skew = false [access.users] diff --git a/src/config/defaults.rs b/src/config/defaults.rs index be540b0..e3d729c 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -86,13 +86,31 @@ pub(crate) fn default_replay_check_len() -> usize { } pub(crate) fn default_replay_window_secs() -> u64 { - 1800 + // Keep replay cache TTL tight by default to reduce replay surface. + // Deployments with higher RTT or longer reconnect jitter can override this in config. + 120 } pub(crate) fn default_handshake_timeout() -> u64 { 30 } +pub(crate) fn default_relay_idle_policy_v2_enabled() -> bool { + true +} + +pub(crate) fn default_relay_client_idle_soft_secs() -> u64 { + 120 +} + +pub(crate) fn default_relay_client_idle_hard_secs() -> u64 { + 360 +} + +pub(crate) fn default_relay_idle_grace_after_downstream_activity_secs() -> u64 { + 30 +} + pub(crate) fn default_connect_timeout() -> u64 { 10 } @@ -485,17 +503,49 @@ pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 { } pub(crate) fn default_server_hello_delay_min_ms() -> u64 { - 0 + 8 } pub(crate) fn default_server_hello_delay_max_ms() -> u64 { - 0 + 24 } pub(crate) fn default_alpn_enforce() -> bool { true } +pub(crate) fn default_mask_shape_hardening() -> bool { + true +} + +pub(crate) fn default_mask_shape_bucket_floor_bytes() -> usize { + 512 +} + +pub(crate) fn default_mask_shape_bucket_cap_bytes() -> usize { + 4096 +} + +pub(crate) fn default_mask_shape_above_cap_blur() -> bool { + false +} + +pub(crate) fn 
default_mask_shape_above_cap_blur_max_bytes() -> usize { + 512 +} + +pub(crate) fn default_mask_timing_normalization_enabled() -> bool { + false +} + +pub(crate) fn default_mask_timing_normalization_floor_ms() -> u64 { + 0 +} + +pub(crate) fn default_mask_timing_normalization_ceiling_ms() -> u64 { + 0 +} + pub(crate) fn default_stun_servers() -> Vec { vec![ "stun.l.google.com:5349".to_string(), diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index 4cf7676..10fc976 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -37,9 +37,7 @@ use crate::config::{ }; use super::load::{LoadedConfig, ProxyConfig}; -const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2; const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50); -const HOT_RELOAD_STABLE_RECHECK: Duration = Duration::from_millis(75); // ── Hot fields ──────────────────────────────────────────────────────────────── @@ -58,11 +56,6 @@ pub struct HotFields { pub me_pool_drain_ttl_secs: u64, pub me_instadrain: bool, pub me_pool_drain_threshold: u64, - pub me_pool_drain_soft_evict_enabled: bool, - pub me_pool_drain_soft_evict_grace_secs: u64, - pub me_pool_drain_soft_evict_per_writer: u8, - pub me_pool_drain_soft_evict_budget_per_core: u16, - pub me_pool_drain_soft_evict_cooldown_ms: u64, pub me_pool_min_fresh_ratio: f32, pub me_reinit_drain_timeout_secs: u64, pub me_hardswap_warmup_delay_min_ms: u64, @@ -146,15 +139,6 @@ impl HotFields { me_pool_drain_ttl_secs: cfg.general.me_pool_drain_ttl_secs, me_instadrain: cfg.general.me_instadrain, me_pool_drain_threshold: cfg.general.me_pool_drain_threshold, - me_pool_drain_soft_evict_enabled: cfg.general.me_pool_drain_soft_evict_enabled, - me_pool_drain_soft_evict_grace_secs: cfg.general.me_pool_drain_soft_evict_grace_secs, - me_pool_drain_soft_evict_per_writer: cfg.general.me_pool_drain_soft_evict_per_writer, - me_pool_drain_soft_evict_budget_per_core: cfg - .general - .me_pool_drain_soft_evict_budget_per_core, - 
me_pool_drain_soft_evict_cooldown_ms: cfg - .general - .me_pool_drain_soft_evict_cooldown_ms, me_pool_min_fresh_ratio: cfg.general.me_pool_min_fresh_ratio, me_reinit_drain_timeout_secs: cfg.general.me_reinit_drain_timeout_secs, me_hardswap_warmup_delay_min_ms: cfg.general.me_hardswap_warmup_delay_min_ms, @@ -346,49 +330,19 @@ impl WatchManifest { #[derive(Debug, Default)] struct ReloadState { applied_snapshot_hash: Option, - candidate_snapshot_hash: Option, - candidate_hits: u8, } impl ReloadState { fn new(applied_snapshot_hash: Option) -> Self { - Self { - applied_snapshot_hash, - candidate_snapshot_hash: None, - candidate_hits: 0, - } + Self { applied_snapshot_hash } } fn is_applied(&self, hash: u64) -> bool { self.applied_snapshot_hash == Some(hash) } - fn observe_candidate(&mut self, hash: u64) -> u8 { - if self.candidate_snapshot_hash == Some(hash) { - self.candidate_hits = self.candidate_hits.saturating_add(1); - } else { - self.candidate_snapshot_hash = Some(hash); - self.candidate_hits = 1; - } - self.candidate_hits - } - - fn reset_candidate(&mut self) { - self.candidate_snapshot_hash = None; - self.candidate_hits = 0; - } - fn mark_applied(&mut self, hash: u64) { self.applied_snapshot_hash = Some(hash); - self.reset_candidate(); - } - - fn pending_candidate(&self) -> Option<(u64, u8)> { - let hash = self.candidate_snapshot_hash?; - if self.candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - return Some((hash, self.candidate_hits)); - } - None } } @@ -481,15 +435,6 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.general.me_pool_drain_ttl_secs = new.general.me_pool_drain_ttl_secs; cfg.general.me_instadrain = new.general.me_instadrain; cfg.general.me_pool_drain_threshold = new.general.me_pool_drain_threshold; - cfg.general.me_pool_drain_soft_evict_enabled = new.general.me_pool_drain_soft_evict_enabled; - cfg.general.me_pool_drain_soft_evict_grace_secs = - new.general.me_pool_drain_soft_evict_grace_secs; - 
cfg.general.me_pool_drain_soft_evict_per_writer = - new.general.me_pool_drain_soft_evict_per_writer; - cfg.general.me_pool_drain_soft_evict_budget_per_core = - new.general.me_pool_drain_soft_evict_budget_per_core; - cfg.general.me_pool_drain_soft_evict_cooldown_ms = - new.general.me_pool_drain_soft_evict_cooldown_ms; cfg.general.me_pool_min_fresh_ratio = new.general.me_pool_min_fresh_ratio; cfg.general.me_reinit_drain_timeout_secs = new.general.me_reinit_drain_timeout_secs; cfg.general.me_hardswap_warmup_delay_min_ms = new.general.me_hardswap_warmup_delay_min_ms; @@ -615,8 +560,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.server.listen_tcp != new.server.listen_tcp || old.server.listen_unix_sock != new.server.listen_unix_sock || old.server.listen_unix_sock_perm != new.server.listen_unix_sock_perm - || old.server.max_connections != new.server.max_connections - || old.server.accept_permit_timeout_ms != new.server.accept_permit_timeout_ms { warned = true; warn!("config reload: server listener settings changed; restart required"); @@ -637,6 +580,21 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b || old.censorship.tls_full_cert_ttl_secs != new.censorship.tls_full_cert_ttl_secs || old.censorship.alpn_enforce != new.censorship.alpn_enforce || old.censorship.mask_proxy_protocol != new.censorship.mask_proxy_protocol + || old.censorship.mask_shape_hardening != new.censorship.mask_shape_hardening + || old.censorship.mask_shape_bucket_floor_bytes + != new.censorship.mask_shape_bucket_floor_bytes + || old.censorship.mask_shape_bucket_cap_bytes + != new.censorship.mask_shape_bucket_cap_bytes + || old.censorship.mask_shape_above_cap_blur + != new.censorship.mask_shape_above_cap_blur + || old.censorship.mask_shape_above_cap_blur_max_bytes + != new.censorship.mask_shape_above_cap_blur_max_bytes + || old.censorship.mask_timing_normalization_enabled + != new.censorship.mask_timing_normalization_enabled 
+ || old.censorship.mask_timing_normalization_floor_ms + != new.censorship.mask_timing_normalization_floor_ms + || old.censorship.mask_timing_normalization_ceiling_ms + != new.censorship.mask_timing_normalization_ceiling_ms { warned = true; warn!("config reload: censorship settings changed; restart required"); @@ -677,9 +635,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b } if old.general.me_route_no_writer_mode != new.general.me_route_no_writer_mode || old.general.me_route_no_writer_wait_ms != new.general.me_route_no_writer_wait_ms - || old.general.me_route_hybrid_max_wait_ms != new.general.me_route_hybrid_max_wait_ms - || old.general.me_route_blocking_send_timeout_ms - != new.general.me_route_blocking_send_timeout_ms || old.general.me_route_inline_recovery_attempts != new.general.me_route_inline_recovery_attempts || old.general.me_route_inline_recovery_wait_ms @@ -688,10 +643,6 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b warned = true; warn!("config reload: general.me_route_no_writer_* changed; restart required"); } - if old.general.me_c2me_send_timeout_ms != new.general.me_c2me_send_timeout_ms { - warned = true; - warn!("config reload: general.me_c2me_send_timeout_ms changed; restart required"); - } if old.general.unknown_dc_log_path != new.general.unknown_dc_log_path || old.general.unknown_dc_file_log_enabled != new.general.unknown_dc_file_log_enabled { @@ -886,25 +837,6 @@ fn log_changes( old_hot.me_pool_drain_threshold, new_hot.me_pool_drain_threshold, ); } - if old_hot.me_pool_drain_soft_evict_enabled != new_hot.me_pool_drain_soft_evict_enabled - || old_hot.me_pool_drain_soft_evict_grace_secs - != new_hot.me_pool_drain_soft_evict_grace_secs - || old_hot.me_pool_drain_soft_evict_per_writer - != new_hot.me_pool_drain_soft_evict_per_writer - || old_hot.me_pool_drain_soft_evict_budget_per_core - != new_hot.me_pool_drain_soft_evict_budget_per_core - || 
old_hot.me_pool_drain_soft_evict_cooldown_ms - != new_hot.me_pool_drain_soft_evict_cooldown_ms - { - info!( - "config reload: me_pool_drain_soft_evict: enabled={} grace={}s per_writer={} budget_per_core={} cooldown={}ms", - new_hot.me_pool_drain_soft_evict_enabled, - new_hot.me_pool_drain_soft_evict_grace_secs, - new_hot.me_pool_drain_soft_evict_per_writer, - new_hot.me_pool_drain_soft_evict_budget_per_core, - new_hot.me_pool_drain_soft_evict_cooldown_ms - ); - } if (old_hot.me_pool_min_fresh_ratio - new_hot.me_pool_min_fresh_ratio).abs() > f32::EPSILON { info!( @@ -1208,7 +1140,6 @@ fn reload_config( let loaded = match ProxyConfig::load_with_metadata(config_path) { Ok(loaded) => loaded, Err(e) => { - reload_state.reset_candidate(); error!("config reload: failed to parse {:?}: {}", config_path, e); return None; } @@ -1221,7 +1152,6 @@ fn reload_config( let next_manifest = WatchManifest::from_source_files(&source_files); if let Err(e) = new_cfg.validate() { - reload_state.reset_candidate(); error!("config reload: validation failed: {}; keeping old config", e); return Some(next_manifest); } @@ -1230,17 +1160,6 @@ fn reload_config( return Some(next_manifest); } - let candidate_hits = reload_state.observe_candidate(rendered_hash); - if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS { - info!( - snapshot_hash = rendered_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - "config reload: candidate snapshot observed but not stable yet" - ); - return Some(next_manifest); - } - let old_cfg = config_tx.borrow().clone(); let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg); let old_hot = HotFields::from_config(&old_cfg); @@ -1260,7 +1179,6 @@ fn reload_config( if old_hot.dns_overrides != applied_hot.dns_overrides && let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides) { - reload_state.reset_candidate(); error!( "config reload: invalid network.dns_overrides: {}; keeping old config", e @@ -1281,73 +1199,6 @@ fn 
reload_config( Some(next_manifest) } -async fn reload_with_internal_stable_rechecks( - config_path: &PathBuf, - config_tx: &watch::Sender>, - log_tx: &watch::Sender, - detected_ip_v4: Option, - detected_ip_v6: Option, - reload_state: &mut ReloadState, -) -> Option { - let mut next_manifest = reload_config( - config_path, - config_tx, - log_tx, - detected_ip_v4, - detected_ip_v6, - reload_state, - ); - let mut rechecks_left = HOT_RELOAD_STABLE_SNAPSHOTS.saturating_sub(1); - - while rechecks_left > 0 { - let Some((snapshot_hash, candidate_hits)) = reload_state.pending_candidate() else { - break; - }; - - info!( - snapshot_hash, - candidate_hits, - required_hits = HOT_RELOAD_STABLE_SNAPSHOTS, - rechecks_left, - recheck_delay_ms = HOT_RELOAD_STABLE_RECHECK.as_millis(), - "config reload: scheduling internal stable recheck" - ); - tokio::time::sleep(HOT_RELOAD_STABLE_RECHECK).await; - - let recheck_manifest = reload_config( - config_path, - config_tx, - log_tx, - detected_ip_v4, - detected_ip_v6, - reload_state, - ); - if recheck_manifest.is_some() { - next_manifest = recheck_manifest; - } - - if reload_state.is_applied(snapshot_hash) { - info!( - snapshot_hash, - "config reload: applied after internal stable recheck" - ); - break; - } - - if reload_state.pending_candidate().is_none() { - info!( - snapshot_hash, - "config reload: internal stable recheck aborted" - ); - break; - } - - rechecks_left = rechecks_left.saturating_sub(1); - } - - next_manifest -} - // ── Public API ──────────────────────────────────────────────────────────────── /// Spawn the hot-reload watcher task. 
@@ -1471,16 +1322,28 @@ pub fn spawn_config_watcher( tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; while notify_rx.try_recv().is_ok() {} - if let Some(next_manifest) = reload_with_internal_stable_rechecks( + let mut next_manifest = reload_config( &config_path, &config_tx, &log_tx, detected_ip_v4, detected_ip_v6, &mut reload_state, - ) - .await - { + ); + if next_manifest.is_none() { + tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await; + while notify_rx.try_recv().is_ok() {} + next_manifest = reload_config( + &config_path, + &config_tx, + &log_tx, + detected_ip_v4, + detected_ip_v6, + &mut reload_state, + ); + } + + if let Some(next_manifest) = next_manifest { apply_watch_manifest( inotify_watcher.as_mut(), poll_watcher.as_mut(), @@ -1605,7 +1468,7 @@ mod tests { } #[test] - fn reload_requires_stable_snapshot_before_hot_apply() { + fn reload_applies_hot_change_on_first_observed_snapshot() { let initial_tag = "11111111111111111111111111111111"; let final_tag = "22222222222222222222222222222222"; let path = temp_config_path("telemt_hot_reload_stable"); @@ -1617,55 +1480,13 @@ mod tests { let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); let mut reload_state = ReloadState::new(Some(initial_hash)); - write_reload_config(&path, None, None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - write_reload_config(&path, Some(final_tag), None); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - assert_eq!( - config_tx.borrow().general.ad_tag.as_deref(), - Some(initial_tag) - ); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); let _ = std::fs::remove_file(path); } - #[tokio::test] - async fn reload_cycle_applies_after_single_external_event() { - let initial_tag = 
"10101010101010101010101010101010"; - let final_tag = "20202020202020202020202020202020"; - let path = temp_config_path("telemt_hot_reload_single_event"); - - write_reload_config(&path, Some(initial_tag), None); - let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); - let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; - let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); - let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); - let mut reload_state = ReloadState::new(Some(initial_hash)); - - write_reload_config(&path, Some(final_tag), None); - reload_with_internal_stable_rechecks( - &path, - &config_tx, - &log_tx, - None, - None, - &mut reload_state, - ) - .await - .unwrap(); - - assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); - let _ = std::fs::remove_file(path); - } - #[test] fn reload_keeps_hot_apply_when_non_hot_fields_change() { let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; @@ -1681,7 +1502,6 @@ mod tests { write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1)); reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); - reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); let applied = config_tx.borrow().clone(); assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag)); @@ -1689,4 +1509,31 @@ mod tests { let _ = std::fs::remove_file(path); } + + #[test] + fn reload_recovers_after_parse_error_on_next_attempt() { + let initial_tag = "cccccccccccccccccccccccccccccccc"; + let final_tag = "dddddddddddddddddddddddddddddddd"; + let path = temp_config_path("telemt_hot_reload_parse_recovery"); + + write_reload_config(&path, Some(initial_tag), None); + let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap()); + let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash; + let (config_tx, _config_rx) = watch::channel(initial_cfg.clone()); + let 
(log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone()); + let mut reload_state = ReloadState::new(Some(initial_hash)); + + std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap(); + assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none()); + assert_eq!( + config_tx.borrow().general.ad_tag.as_deref(), + Some(initial_tag) + ); + + write_reload_config(&path, Some(final_tag), None); + reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap(); + assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag)); + + let _ = std::fs::remove_file(path); + } } diff --git a/src/config/load.rs b/src/config/load.rs index c797637..2c50f4e 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -5,7 +5,7 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; -use rand::Rng; +use rand::RngExt; use serde::{Deserialize, Serialize}; use shadowsocks::config::ServerConfig as ShadowsocksServerConfig; use tracing::warn; @@ -360,6 +360,125 @@ impl ProxyConfig { )); } + if config.timeouts.client_handshake == 0 { + return Err(ProxyError::Config( + "timeouts.client_handshake must be > 0".to_string(), + )); + } + + let handshake_timeout_ms = config + .timeouts + .client_handshake + .checked_mul(1000) + .ok_or_else(|| { + ProxyError::Config( + "timeouts.client_handshake is too large to validate milliseconds budget" + .to_string(), + ) + })?; + + if config.censorship.server_hello_delay_max_ms >= handshake_timeout_ms { + return Err(ProxyError::Config( + "censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000" + .to_string(), + )); + } + + if config.censorship.mask_shape_bucket_floor_bytes == 0 { + return Err(ProxyError::Config( + "censorship.mask_shape_bucket_floor_bytes must be > 0".to_string(), + )); + } + + if config.censorship.mask_shape_bucket_cap_bytes + < 
config.censorship.mask_shape_bucket_floor_bytes + { + return Err(ProxyError::Config( + "censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur + && !config.censorship.mask_shape_hardening + { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur + && config.censorship.mask_shape_above_cap_blur_max_bytes == 0 + { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled" + .to_string(), + )); + } + + if config.censorship.mask_shape_above_cap_blur_max_bytes > 1_048_576 { + return Err(ProxyError::Config( + "censorship.mask_shape_above_cap_blur_max_bytes must be <= 1048576" + .to_string(), + )); + } + + if config.censorship.mask_timing_normalization_ceiling_ms + < config.censorship.mask_timing_normalization_floor_ms + { + return Err(ProxyError::Config( + "censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms" + .to_string(), + )); + } + + if config.censorship.mask_timing_normalization_enabled + && config.censorship.mask_timing_normalization_floor_ms == 0 + { + return Err(ProxyError::Config( + "censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true" + .to_string(), + )); + } + + if config.censorship.mask_timing_normalization_ceiling_ms > 60_000 { + return Err(ProxyError::Config( + "censorship.mask_timing_normalization_ceiling_ms must be <= 60000" + .to_string(), + )); + } + + if config.timeouts.relay_client_idle_soft_secs == 0 { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_soft_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs == 0 { + return 
Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be > 0".to_string(), + )); + } + + if config.timeouts.relay_client_idle_hard_secs + < config.timeouts.relay_client_idle_soft_secs + { + return Err(ProxyError::Config( + "timeouts.relay_client_idle_hard_secs must be >= timeouts.relay_client_idle_soft_secs" + .to_string(), + )); + } + + if config.timeouts.relay_idle_grace_after_downstream_activity_secs + > config.timeouts.relay_client_idle_hard_secs + { + return Err(ProxyError::Config( + "timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs" + .to_string(), + )); + } + if config.general.me_writer_cmd_channel_capacity == 0 { return Err(ProxyError::Config( "general.me_writer_cmd_channel_capacity must be > 0".to_string(), @@ -860,7 +979,7 @@ impl ProxyConfig { if !config.censorship.tls_emulation && config.censorship.fake_cert_len == default_fake_cert_len() { - config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); + config.censorship.fake_cert_len = rand::rng().random_range(1024..4096); } // Resolve listen_tcp: explicit value wins, otherwise auto-detect. 
@@ -982,6 +1101,18 @@ impl ProxyConfig { } } +#[cfg(test)] +#[path = "tests/load_idle_policy_tests.rs"] +mod load_idle_policy_tests; + +#[cfg(test)] +#[path = "tests/load_security_tests.rs"] +mod load_security_tests; + +#[cfg(test)] +#[path = "tests/load_mask_shape_security_tests.rs"] +mod load_mask_shape_security_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/src/config/tests/load_idle_policy_tests.rs b/src/config/tests/load_idle_policy_tests.rs new file mode 100644 index 0000000..087fd75 --- /dev/null +++ b/src/config/tests/load_idle_policy_tests.rs @@ -0,0 +1,78 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-idle-policy-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_relay_hard_idle_smaller_than_soft_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 120 +relay_client_idle_hard_secs = 60 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with hard= timeouts.relay_client_idle_soft_secs"), + "error must explain the violated hard>=soft invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_relay_grace_larger_than_hard_idle_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +relay_client_idle_soft_secs = 60 +relay_client_idle_hard_secs = 120 +relay_idle_grace_after_downstream_activity_secs = 121 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with grace>hard must fail"); + let msg = err.to_string(); + assert!( + 
msg.contains("timeouts.relay_idle_grace_after_downstream_activity_secs must be <= timeouts.relay_client_idle_hard_secs"), + "error must explain the violated grace<=hard invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_zero_handshake_timeout_with_clear_error() { + let path = write_temp_config( + r#" +[timeouts] +client_handshake = 0 +"#, + ); + + let err = ProxyConfig::load(&path).expect_err("config with zero handshake timeout must fail"); + let msg = err.to_string(); + assert!( + msg.contains("timeouts.client_handshake must be > 0"), + "error must explain that handshake timeout must be positive, got: {msg}" + ); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs new file mode 100644 index 0000000..41df0f5 --- /dev/null +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -0,0 +1,195 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-mask-shape-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_zero_mask_shape_bucket_floor_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_bucket_floor_bytes = 0 +mask_shape_bucket_cap_bytes = 4096 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("zero mask_shape_bucket_floor_bytes must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_bucket_floor_bytes must be > 0"), + "error must explain floor>0 invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + 
+#[test] +fn load_rejects_mask_shape_bucket_cap_less_than_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_bucket_floor_bytes = 1024 +mask_shape_bucket_cap_bytes = 512 +"#, + ); + + let err = + ProxyConfig::load(&path).expect_err("mask_shape_bucket_cap_bytes < floor must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains( + "censorship.mask_shape_bucket_cap_bytes must be >= censorship.mask_shape_bucket_floor_bytes" + ), + "error must explain cap>=floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_mask_shape_bucket_cap_equal_to_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_bucket_floor_bytes = 1024 +mask_shape_bucket_cap_bytes = 1024 +"#, + ); + + let cfg = ProxyConfig::load(&path).expect("equal cap and floor must be accepted"); + assert!(cfg.censorship.mask_shape_hardening); + assert_eq!(cfg.censorship.mask_shape_bucket_floor_bytes, 1024); + assert_eq!(cfg.censorship.mask_shape_bucket_cap_bytes, 1024); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_above_cap_blur_when_shape_hardening_disabled() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = false +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 64 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("above-cap blur must require shape hardening enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_above_cap_blur requires censorship.mask_shape_hardening = true"), + "error must explain blur prerequisite, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_above_cap_blur_with_zero_max_bytes() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 0 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("above-cap blur max bytes 
must be > 0 when enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_shape_above_cap_blur_max_bytes must be > 0 when censorship.mask_shape_above_cap_blur is enabled"), + "error must explain blur max bytes invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_timing_normalization_floor_zero_when_enabled() { + let path = write_temp_config( + r#" +[censorship] +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 0 +mask_timing_normalization_ceiling_ms = 200 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("timing normalization floor must be > 0 when enabled"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_timing_normalization_floor_ms must be > 0 when censorship.mask_timing_normalization_enabled is true"), + "error must explain timing floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_rejects_timing_normalization_ceiling_below_floor() { + let path = write_temp_config( + r#" +[censorship] +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 220 +mask_timing_normalization_ceiling_ms = 200 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("timing normalization ceiling must be >= floor"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.mask_timing_normalization_ceiling_ms must be >= censorship.mask_timing_normalization_floor_ms"), + "error must explain timing ceiling/floor invariant, got: {msg}" + ); + + remove_temp_config(&path); +} + +#[test] +fn load_accepts_valid_timing_normalization_and_above_cap_blur_config() { + let path = write_temp_config( + r#" +[censorship] +mask_shape_hardening = true +mask_shape_above_cap_blur = true +mask_shape_above_cap_blur_max_bytes = 128 +mask_timing_normalization_enabled = true +mask_timing_normalization_floor_ms = 150 +mask_timing_normalization_ceiling_ms = 240 +"#, + ); + + let cfg = ProxyConfig::load(&path) 
+ .expect("valid blur and timing normalization settings must be accepted"); + assert!(cfg.censorship.mask_shape_hardening); + assert!(cfg.censorship.mask_shape_above_cap_blur); + assert_eq!(cfg.censorship.mask_shape_above_cap_blur_max_bytes, 128); + assert!(cfg.censorship.mask_timing_normalization_enabled); + assert_eq!(cfg.censorship.mask_timing_normalization_floor_ms, 150); + assert_eq!(cfg.censorship.mask_timing_normalization_ceiling_ms, 240); + + remove_temp_config(&path); +} diff --git a/src/config/tests/load_security_tests.rs b/src/config/tests/load_security_tests.rs new file mode 100644 index 0000000..a1a35ac --- /dev/null +++ b/src/config/tests/load_security_tests.rs @@ -0,0 +1,84 @@ +use super::*; +use std::fs; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn write_temp_config(contents: &str) -> PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time must be after unix epoch") + .as_nanos(); + let path = std::env::temp_dir().join(format!("telemt-load-security-{nonce}.toml")); + fs::write(&path, contents).expect("temp config write must succeed"); + path +} + +fn remove_temp_config(path: &PathBuf) { + let _ = fs::remove_file(path); +} + +#[test] +fn load_rejects_server_hello_delay_equal_to_handshake_timeout_budget() { + let path = write_temp_config( + r#" +[timeouts] +client_handshake = 1 + +[censorship] +server_hello_delay_max_ms = 1000 +"#, + ); + + let err = ProxyConfig::load(&path) + .expect_err("delay equal to handshake timeout must be rejected"); + let msg = err.to_string(); + assert!( + msg.contains("censorship.server_hello_delay_max_ms must be < timeouts.client_handshake * 1000"), + "error must explain delay, + /// Port for the Prometheus-compatible metrics endpoint. /// Enables metrics when set; binds on all interfaces (dual-stack) by default. 
#[serde(default)] @@ -1270,6 +1277,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), + proxy_protocol_trusted_cidrs: Vec::new(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), @@ -1286,6 +1294,24 @@ pub struct TimeoutsConfig { #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, + /// Enables soft/hard relay client idle policy for middle-relay sessions. + #[serde(default = "default_relay_idle_policy_v2_enabled")] + pub relay_idle_policy_v2_enabled: bool, + + /// Soft idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold marks the session as idle-candidate, but does not close it. + #[serde(default = "default_relay_client_idle_soft_secs")] + pub relay_client_idle_soft_secs: u64, + + /// Hard idle threshold for middle-relay client uplink activity in seconds. + /// Hitting this threshold closes the session. + #[serde(default = "default_relay_client_idle_hard_secs")] + pub relay_client_idle_hard_secs: u64, + + /// Additional grace in seconds added to hard idle window after recent downstream activity. 
+ #[serde(default = "default_relay_idle_grace_after_downstream_activity_secs")] + pub relay_idle_grace_after_downstream_activity_secs: u64, + #[serde(default = "default_connect_timeout")] pub tg_connect: u64, @@ -1308,6 +1334,11 @@ impl Default for TimeoutsConfig { fn default() -> Self { Self { client_handshake: default_handshake_timeout(), + relay_idle_policy_v2_enabled: default_relay_idle_policy_v2_enabled(), + relay_client_idle_soft_secs: default_relay_client_idle_soft_secs(), + relay_client_idle_hard_secs: default_relay_client_idle_hard_secs(), + relay_idle_grace_after_downstream_activity_secs: + default_relay_idle_grace_after_downstream_activity_secs(), tg_connect: default_connect_timeout(), client_keepalive: default_keepalive(), client_ack: default_ack_timeout(), @@ -1381,6 +1412,40 @@ pub struct AntiCensorshipConfig { /// Allows the backend to see the real client IP. #[serde(default)] pub mask_proxy_protocol: u8, + + /// Enable shape-channel hardening on mask backend path by padding + /// client->mask stream tail to configured buckets on stream end. + #[serde(default = "default_mask_shape_hardening")] + pub mask_shape_hardening: bool, + + /// Minimum bucket size for mask shape hardening padding. + #[serde(default = "default_mask_shape_bucket_floor_bytes")] + pub mask_shape_bucket_floor_bytes: usize, + + /// Maximum bucket size for mask shape hardening padding. + #[serde(default = "default_mask_shape_bucket_cap_bytes")] + pub mask_shape_bucket_cap_bytes: usize, + + /// Add bounded random tail bytes even when total bytes already exceed + /// mask_shape_bucket_cap_bytes. + #[serde(default = "default_mask_shape_above_cap_blur")] + pub mask_shape_above_cap_blur: bool, + + /// Maximum random bytes appended above cap when above-cap blur is enabled. + #[serde(default = "default_mask_shape_above_cap_blur_max_bytes")] + pub mask_shape_above_cap_blur_max_bytes: usize, + + /// Enable outcome-time normalization envelope for masking fallback. 
+ #[serde(default = "default_mask_timing_normalization_enabled")] + pub mask_timing_normalization_enabled: bool, + + /// Lower bound (ms) for masking outcome timing envelope. + #[serde(default = "default_mask_timing_normalization_floor_ms")] + pub mask_timing_normalization_floor_ms: u64, + + /// Upper bound (ms) for masking outcome timing envelope. + #[serde(default = "default_mask_timing_normalization_ceiling_ms")] + pub mask_timing_normalization_ceiling_ms: u64, } impl Default for AntiCensorshipConfig { @@ -1402,6 +1467,14 @@ impl Default for AntiCensorshipConfig { tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(), alpn_enforce: default_alpn_enforce(), mask_proxy_protocol: 0, + mask_shape_hardening: default_mask_shape_hardening(), + mask_shape_bucket_floor_bytes: default_mask_shape_bucket_floor_bytes(), + mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(), + mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(), + mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(), + mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(), + mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(), + mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(), } } } diff --git a/src/crypto/random.rs b/src/crypto/random.rs index a88efc6..2f52188 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -3,7 +3,7 @@ #![allow(deprecated)] #![allow(dead_code)] -use rand::{Rng, RngCore, SeedableRng}; +use rand::{Rng, RngExt, SeedableRng}; use rand::rngs::StdRng; use parking_lot::Mutex; use zeroize::Zeroize; @@ -101,7 +101,7 @@ impl SecureRandom { return 0; } let mut inner = self.inner.lock(); - inner.rng.gen_range(0..max) + inner.rng.random_range(0..max) } /// Generate random bits @@ -141,7 +141,7 @@ impl SecureRandom { pub fn shuffle(&self, slice: &mut [T]) { let mut inner = self.inner.lock(); for i in (1..slice.len()).rev() { - let j = 
inner.rng.gen_range(0..=i); + let j = inner.rng.random_range(0..=i); slice.swap(i, j); } } diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index fce20b6..c35c587 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -7,8 +7,9 @@ use std::net::IpAddr; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; +use std::sync::Mutex; -use tokio::sync::RwLock; +use tokio::sync::{Mutex as AsyncMutex, RwLock}; use crate::config::UserMaxUniqueIpsMode; @@ -21,6 +22,8 @@ pub struct UserIpTracker { limit_mode: Arc>, limit_window: Arc>, last_compact_epoch_secs: Arc, + pub(crate) cleanup_queue: Arc>>, + cleanup_drain_lock: Arc>, } impl UserIpTracker { @@ -33,6 +36,67 @@ impl UserIpTracker { limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), + cleanup_queue: Arc::new(Mutex::new(Vec::new())), + cleanup_drain_lock: Arc::new(AsyncMutex::new(())), + } + } + + + pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) { + match self.cleanup_queue.lock() { + Ok(mut queue) => queue.push((user, ip)), + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + queue.push((user.clone(), ip)); + self.cleanup_queue.clear_poison(); + tracing::warn!( + "UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})", + user, + ip + ); + } + } + } + + pub(crate) async fn drain_cleanup_queue(&self) { + // Serialize queue draining and active-IP mutation so check-and-add cannot + // observe stale active entries that are already queued for removal. 
+ let _drain_guard = self.cleanup_drain_lock.lock().await; + let to_remove = { + match self.cleanup_queue.lock() { + Ok(mut queue) => { + if queue.is_empty() { + return; + } + std::mem::take(&mut *queue) + } + Err(poisoned) => { + let mut queue = poisoned.into_inner(); + if queue.is_empty() { + self.cleanup_queue.clear_poison(); + return; + } + let drained = std::mem::take(&mut *queue); + self.cleanup_queue.clear_poison(); + drained + } + } + }; + + let mut active_ips = self.active_ips.write().await; + for (user, ip) in to_remove { + if let Some(user_ips) = active_ips.get_mut(&user) { + if let Some(count) = user_ips.get_mut(&ip) { + if *count > 1 { + *count -= 1; + } else { + user_ips.remove(&ip); + } + } + if user_ips.is_empty() { + active_ips.remove(&user); + } + } } } @@ -118,6 +182,7 @@ impl UserIpTracker { } pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { + self.drain_cleanup_queue().await; self.maybe_compact_empty_users().await; let default_max_ips = *self.default_max_ips.read().await; let limit = { @@ -194,6 +259,7 @@ impl UserIpTracker { } pub async fn get_recent_counts_for_users(&self, users: &[String]) -> HashMap { + self.drain_cleanup_queue().await; let window = *self.limit_window.read().await; let now = Instant::now(); let recent_ips = self.recent_ips.read().await; @@ -214,6 +280,7 @@ impl UserIpTracker { } pub async fn get_active_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; let mut out = HashMap::with_capacity(users.len()); for user in users { @@ -228,6 +295,7 @@ impl UserIpTracker { } pub async fn get_recent_ips_for_users(&self, users: &[String]) -> HashMap> { + self.drain_cleanup_queue().await; let window = *self.limit_window.read().await; let now = Instant::now(); let recent_ips = self.recent_ips.read().await; @@ -250,11 +318,13 @@ impl UserIpTracker { } pub async fn get_active_ip_count(&self, username: &str) -> usize { + 
self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips.get(username).map(|ips| ips.len()).unwrap_or(0) } pub async fn get_active_ips(&self, username: &str) -> Vec { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips .get(username) @@ -263,6 +333,7 @@ impl UserIpTracker { } pub async fn get_stats(&self) -> Vec<(String, usize, usize)> { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; let max_ips = self.max_ips.read().await; let default_max_ips = *self.default_max_ips.read().await; @@ -301,6 +372,7 @@ impl UserIpTracker { } pub async fn is_ip_active(&self, username: &str, ip: IpAddr) -> bool { + self.drain_cleanup_queue().await; let active_ips = self.active_ips.read().await; active_ips .get(username) diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index f43e308..f916633 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{ ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, }; +pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf { + let raw = PathBuf::from(config_path_cli); + let absolute = if raw.is_absolute() { + raw + } else { + startup_cwd.join(raw) + }; + absolute.canonicalize().unwrap_or(absolute) +} + pub(crate) fn parse_cli() -> (String, Option, bool, Option) { let mut config_path = "config.toml".to_string(); let mut data_path: Option = None; @@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option, bool, Option) { (config_path, data_path, silent, log_level) } +#[cfg(test)] +mod tests { + use super::resolve_runtime_config_path; + + #[test] + fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = 
std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + let target = startup_cwd.join("config.toml"); + std::fs::write(&target, " ").unwrap(); + + let resolved = resolve_runtime_config_path("config.toml", &startup_cwd); + assert_eq!(resolved, target.canonicalize().unwrap()); + + let _ = std::fs::remove_file(&target); + let _ = std::fs::remove_dir(&startup_cwd); + } + + #[test] + fn resolve_runtime_config_path_keeps_absolute_for_missing_file() { + let nonce = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}")); + std::fs::create_dir_all(&startup_cwd).unwrap(); + + let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd); + assert_eq!(resolved, startup_cwd.join("missing.toml")); + + let _ = std::fs::remove_dir(&startup_cwd); + } +} + pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) { info!(target: "telemt::links", "--- Proxy Links ({}) ---", host); for user_name in config.general.links.show.resolve_users(&config.access.users) { diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index eb45cc4..bbe46a8 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -268,8 +268,6 @@ pub(crate) async fn initialize_me_pool( config.general.me_warn_rate_limit_ms, config.general.me_route_no_writer_mode, config.general.me_route_no_writer_wait_ms, - config.general.me_route_hybrid_max_wait_ms, - config.general.me_route_blocking_send_timeout_ms, config.general.me_route_inline_recovery_attempts, config.general.me_route_inline_recovery_wait_ms, ); @@ -517,7 +515,7 @@ pub(crate) async fn initialize_me_pool( } } }); - + break Some(pool); } Err(e) => { diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index dce421c..7ba7b39 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -45,7 +45,7 @@ use crate::startup::{ 
use crate::stream::BufferPool; use crate::transport::middle_proxy::MePool; use crate::transport::UpstreamManager; -use helpers::parse_cli; +use helpers::{parse_cli, resolve_runtime_config_path}; /// Runs the full telemt runtime startup pipeline and blocks until shutdown. pub async fn run() -> std::result::Result<(), Box> { @@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box> { startup_tracker .start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string())) .await; - let (config_path, data_path, cli_silent, cli_log_level) = parse_cli(); + let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli(); + let startup_cwd = match std::env::current_dir() { + Ok(cwd) => cwd, + Err(e) => { + eprintln!("[telemt] Can't read current_dir: {}", e); + std::process::exit(1); + } + }; + let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd); let mut config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - if std::path::Path::new(&config_path).exists() { + if config_path.exists() { eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); - eprintln!("[telemt] Created default config at {}", config_path); + eprintln!("[telemt] Created default config at {}", config_path.display()); default } } @@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box> { let route_runtime_api = route_runtime.clone(); let config_rx_api = api_config_rx.clone(); let admission_rx_api = admission_rx.clone(); - let config_path_api = std::path::PathBuf::from(&config_path); + let config_path_api = config_path.clone(); let startup_tracker_api = startup_tracker.clone(); let detected_ips_rx_api = detected_ips_rx.clone(); tokio::spawn(async move { diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index d9691a8..c2233c7 100644 --- a/src/maestro/runtime_tasks.rs 
+++ b/src/maestro/runtime_tasks.rs @@ -1,5 +1,5 @@ use std::net::IpAddr; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use tokio::sync::{mpsc, watch}; @@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches { #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, - config_path: &str, + config_path: &Path, probe: &NetworkProbe, prefer_ipv6: bool, decision_ipv4_dc: bool, @@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks( watch::Receiver>, watch::Receiver, ) = spawn_config_watcher( - PathBuf::from(config_path), + config_path.to_path_buf(), config.clone(), detected_ip_v4, detected_ip_v6, diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 73eec4c..342a2f9 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Duration; -use rand::Rng; +use rand::RngExt; use tracing::warn; use crate::config::ProxyConfig; diff --git a/src/main.rs b/src/main.rs index 2cfbe28..dff8c8a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,11 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] +#[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; +#[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; mod maestro; mod metrics; mod network; diff --git a/src/metrics.rs b/src/metrics.rs index b7272b2..b7a16f0 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -16,9 +16,7 @@ use tracing::{info, warn, debug}; use crate::config::ProxyConfig; use crate::ip_tracker::UserIpTracker; use crate::stats::beobachten::BeobachtenStore; -use crate::stats::{ - MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, Stats, -}; +use crate::stats::Stats; use crate::transport::{ListenOptions, create_listener}; pub async fn serve( @@ -294,109 +292,6 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp 
"telemt_connections_bad_total {}", if core_enabled { stats.get_connects_bad() } else { 0 } ); - let _ = writeln!(out, "# HELP telemt_connections_current Current active connections"); - let _ = writeln!(out, "# TYPE telemt_connections_current gauge"); - let _ = writeln!( - out, - "telemt_connections_current {}", - if core_enabled { - stats.get_current_connections_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_connections_direct_current Current active direct connections"); - let _ = writeln!(out, "# TYPE telemt_connections_direct_current gauge"); - let _ = writeln!( - out, - "telemt_connections_direct_current {}", - if core_enabled { - stats.get_current_connections_direct() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_connections_me_current Current active middle-end connections"); - let _ = writeln!(out, "# TYPE telemt_connections_me_current gauge"); - let _ = writeln!( - out, - "telemt_connections_me_current {}", - if core_enabled { - stats.get_current_connections_me() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_promotions_total Adaptive relay tier promotions" - ); - let _ = writeln!(out, "# TYPE telemt_relay_adaptive_promotions_total counter"); - let _ = writeln!( - out, - "telemt_relay_adaptive_promotions_total {}", - if core_enabled { - stats.get_relay_adaptive_promotions_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_demotions_total Adaptive relay tier demotions" - ); - let _ = writeln!(out, "# TYPE telemt_relay_adaptive_demotions_total counter"); - let _ = writeln!( - out, - "telemt_relay_adaptive_demotions_total {}", - if core_enabled { - stats.get_relay_adaptive_demotions_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_relay_adaptive_hard_promotions_total Adaptive relay hard promotions triggered by write pressure" - ); - let _ = writeln!( - out, - "# TYPE telemt_relay_adaptive_hard_promotions_total counter" - 
); - let _ = writeln!( - out, - "telemt_relay_adaptive_hard_promotions_total {}", - if core_enabled { - stats.get_relay_adaptive_hard_promotions_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_reconnect_evict_total Reconnect-driven session evictions"); - let _ = writeln!(out, "# TYPE telemt_reconnect_evict_total counter"); - let _ = writeln!( - out, - "telemt_reconnect_evict_total {}", - if core_enabled { - stats.get_reconnect_evict_total() - } else { - 0 - } - ); - let _ = writeln!( - out, - "# HELP telemt_reconnect_stale_close_total Sessions closed because they became stale after reconnect" - ); - let _ = writeln!(out, "# TYPE telemt_reconnect_stale_close_total counter"); - let _ = writeln!( - out, - "telemt_reconnect_stale_close_total {}", - if core_enabled { - stats.get_reconnect_stale_close_total() - } else { - 0 - } - ); let _ = writeln!(out, "# HELP telemt_handshake_timeouts_total Handshake timeouts"); let _ = writeln!(out, "# TYPE telemt_handshake_timeouts_total counter"); @@ -810,6 +705,69 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_relay_idle_soft_mark_total Middle-relay sessions marked as soft-idle candidates" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_soft_mark_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_soft_mark_total {}", + if me_allows_normal { + stats.get_relay_idle_soft_mark_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_idle_hard_close_total Middle-relay sessions closed by hard-idle policy" + ); + let _ = writeln!(out, "# TYPE telemt_relay_idle_hard_close_total counter"); + let _ = writeln!( + out, + "telemt_relay_idle_hard_close_total {}", + if me_allows_normal { + stats.get_relay_idle_hard_close_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_pressure_evict_total Middle-relay sessions evicted under resource pressure" + ); + let _ 
= writeln!(out, "# TYPE telemt_relay_pressure_evict_total counter"); + let _ = writeln!( + out, + "telemt_relay_pressure_evict_total {}", + if me_allows_normal { + stats.get_relay_pressure_evict_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_relay_protocol_desync_close_total Middle-relay sessions closed due to protocol desync" + ); + let _ = writeln!( + out, + "# TYPE telemt_relay_protocol_desync_close_total counter" + ); + let _ = writeln!( + out, + "telemt_relay_protocol_desync_close_total {}", + if me_allows_normal { + stats.get_relay_protocol_desync_close_total() + } else { + 0 + } + ); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); let _ = writeln!( @@ -1652,36 +1610,6 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!( - out, - "# HELP telemt_pool_drain_soft_evict_total Soft-evicted client sessions on stuck draining writers" - ); - let _ = writeln!(out, "# TYPE telemt_pool_drain_soft_evict_total counter"); - let _ = writeln!( - out, - "telemt_pool_drain_soft_evict_total {}", - if me_allows_normal { - stats.get_pool_drain_soft_evict_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_pool_drain_soft_evict_writer_total Draining writers with at least one soft eviction" - ); - let _ = writeln!(out, "# TYPE telemt_pool_drain_soft_evict_writer_total counter"); - let _ = writeln!( - out, - "telemt_pool_drain_soft_evict_writer_total {}", - if me_allows_normal { - stats.get_pool_drain_soft_evict_writer_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_pool_stale_pick_total Stale writer fallback picks for new binds"); let _ = writeln!(out, "# TYPE telemt_pool_stale_pick_total counter"); let _ = writeln!( @@ -1694,57 +1622,6 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!( - 
out, - "# HELP telemt_me_writer_close_signal_drop_total Close-signal drops for already-removed ME writers" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_close_signal_drop_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_close_signal_drop_total {}", - if me_allows_normal { - stats.get_me_writer_close_signal_drop_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_close_signal_channel_full_total Close-signal drops caused by full writer command channels" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_close_signal_channel_full_total counter" - ); - let _ = writeln!( - out, - "telemt_me_writer_close_signal_channel_full_total {}", - if me_allows_normal { - stats.get_me_writer_close_signal_channel_full_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_draining_writers_reap_progress_total Draining-writer removals processed by reap cleanup" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_draining_writers_reap_progress_total counter" - ); - let _ = writeln!( - out, - "telemt_me_draining_writers_reap_progress_total {}", - if me_allows_normal { - stats.get_me_draining_writers_reap_progress_total() - } else { - 0 - } - ); - let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals"); let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter"); let _ = writeln!( @@ -1772,169 +1649,6 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_attempt_total ME writer teardown attempts by reason and mode" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_attempt_total counter"); - for reason in MeWriterTeardownReason::ALL { - for mode in MeWriterTeardownMode::ALL { - let _ = writeln!( - out, - "telemt_me_writer_teardown_attempt_total{{reason=\"{}\",mode=\"{}\"}} {}", - reason.as_str(), - mode.as_str(), - if me_allows_normal { - 
stats.get_me_writer_teardown_attempt_total(reason, mode) - } else { - 0 - } - ); - } - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_success_total ME writer teardown successes by mode" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_success_total counter"); - for mode in MeWriterTeardownMode::ALL { - let _ = writeln!( - out, - "telemt_me_writer_teardown_success_total{{mode=\"{}\"}} {}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_success_total(mode) - } else { - 0 - } - ); - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_timeout_total Teardown operations that timed out" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_timeout_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_teardown_timeout_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_timeout_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_escalation_total Watchdog teardown escalations to hard detach" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_teardown_escalation_total counter" - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_escalation_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_escalation_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_noop_total Teardown operations that became no-op" - ); - let _ = writeln!(out, "# TYPE telemt_me_writer_teardown_noop_total counter"); - let _ = writeln!( - out, - "telemt_me_writer_teardown_noop_total {}", - if me_allows_normal { - stats.get_me_writer_teardown_noop_total() - } else { - 0 - } - ); - - let _ = writeln!( - out, - "# HELP telemt_me_writer_teardown_duration_seconds ME writer teardown latency histogram by mode" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_teardown_duration_seconds histogram" - ); - let bucket_labels = Stats::me_writer_teardown_duration_bucket_labels(); - for mode 
in MeWriterTeardownMode::ALL { - for (bucket_idx, label) in bucket_labels.iter().enumerate() { - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_bucket{{mode=\"{}\",le=\"{}\"}} {}", - mode.as_str(), - label, - if me_allows_normal { - stats.get_me_writer_teardown_duration_bucket_total(mode, bucket_idx) - } else { - 0 - } - ); - } - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_bucket{{mode=\"{}\",le=\"+Inf\"}} {}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_count(mode) - } else { - 0 - } - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_sum{{mode=\"{}\"}} {:.6}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_sum_seconds(mode) - } else { - 0.0 - } - ); - let _ = writeln!( - out, - "telemt_me_writer_teardown_duration_seconds_count{{mode=\"{}\"}} {}", - mode.as_str(), - if me_allows_normal { - stats.get_me_writer_teardown_duration_count(mode) - } else { - 0 - } - ); - } - - let _ = writeln!( - out, - "# HELP telemt_me_writer_cleanup_side_effect_failures_total Failed cleanup side effects by step" - ); - let _ = writeln!( - out, - "# TYPE telemt_me_writer_cleanup_side_effect_failures_total counter" - ); - for step in MeWriterCleanupSideEffectStep::ALL { - let _ = writeln!( - out, - "telemt_me_writer_cleanup_side_effect_failures_total{{step=\"{}\"}} {}", - step.as_str(), - if me_allows_normal { - stats.get_me_writer_cleanup_side_effect_failures_total(step) - } else { - 0 - } - ); - } - let _ = writeln!(out, "# HELP telemt_me_refill_triggered_total Immediate ME refill runs started"); let _ = writeln!(out, "# TYPE telemt_me_refill_triggered_total counter"); let _ = writeln!( @@ -2213,8 +1927,6 @@ mod tests { stats.increment_connects_all(); stats.increment_connects_all(); stats.increment_connects_bad(); - stats.increment_current_connections_direct(); - stats.increment_current_connections_me(); 
stats.increment_handshake_timeouts(); stats.increment_upstream_connect_attempt_total(); stats.increment_upstream_connect_attempt_total(); @@ -2230,6 +1942,10 @@ mod tests { stats.increment_me_rpc_proxy_req_signal_response_total(); stats.increment_me_rpc_proxy_req_signal_close_sent_total(); stats.increment_me_idle_close_by_peer_total(); + stats.increment_relay_idle_soft_mark_total(); + stats.increment_relay_idle_hard_close_total(); + stats.increment_relay_pressure_evict_total(); + stats.increment_relay_protocol_desync_close_total(); stats.increment_user_connects("alice"); stats.increment_user_curr_connects("alice"); stats.add_user_octets_from("alice", 1024); @@ -2246,9 +1962,6 @@ mod tests { assert!(output.contains("telemt_connections_total 2")); assert!(output.contains("telemt_connections_bad_total 1")); - assert!(output.contains("telemt_connections_current 2")); - assert!(output.contains("telemt_connections_direct_current 1")); - assert!(output.contains("telemt_connections_me_current 1")); assert!(output.contains("telemt_handshake_timeouts_total 1")); assert!(output.contains("telemt_upstream_connect_attempt_total 2")); assert!(output.contains("telemt_upstream_connect_success_total 1")); @@ -2271,6 +1984,10 @@ mod tests { assert!(output.contains("telemt_me_rpc_proxy_req_signal_response_total 1")); assert!(output.contains("telemt_me_rpc_proxy_req_signal_close_sent_total 1")); assert!(output.contains("telemt_me_idle_close_by_peer_total 1")); + assert!(output.contains("telemt_relay_idle_soft_mark_total 1")); + assert!(output.contains("telemt_relay_idle_hard_close_total 1")); + assert!(output.contains("telemt_relay_pressure_evict_total 1")); + assert!(output.contains("telemt_relay_protocol_desync_close_total 1")); assert!(output.contains("telemt_user_connections_total{user=\"alice\"} 1")); assert!(output.contains("telemt_user_connections_current{user=\"alice\"} 1")); assert!(output.contains("telemt_user_octets_from_client{user=\"alice\"} 1024")); @@ -2291,9 +2008,6 @@ 
mod tests { let output = render_metrics(&stats, &config, &tracker).await; assert!(output.contains("telemt_connections_total 0")); assert!(output.contains("telemt_connections_bad_total 0")); - assert!(output.contains("telemt_connections_current 0")); - assert!(output.contains("telemt_connections_direct_current 0")); - assert!(output.contains("telemt_connections_me_current 0")); assert!(output.contains("telemt_handshake_timeouts_total 0")); assert!(output.contains("telemt_user_unique_ips_current{user=")); assert!(output.contains("telemt_user_unique_ips_recent_window{user=")); @@ -2327,39 +2041,15 @@ mod tests { assert!(output.contains("# TYPE telemt_uptime_seconds gauge")); assert!(output.contains("# TYPE telemt_connections_total counter")); assert!(output.contains("# TYPE telemt_connections_bad_total counter")); - assert!(output.contains("# TYPE telemt_connections_current gauge")); - assert!(output.contains("# TYPE telemt_connections_direct_current gauge")); - assert!(output.contains("# TYPE telemt_connections_me_current gauge")); - assert!(output.contains("# TYPE telemt_relay_adaptive_promotions_total counter")); - assert!(output.contains("# TYPE telemt_relay_adaptive_demotions_total counter")); - assert!(output.contains("# TYPE telemt_relay_adaptive_hard_promotions_total counter")); - assert!(output.contains("# TYPE telemt_reconnect_evict_total counter")); - assert!(output.contains("# TYPE telemt_reconnect_stale_close_total counter")); assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter")); assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter")); assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter")); assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_soft_mark_total counter")); + assert!(output.contains("# TYPE telemt_relay_idle_hard_close_total counter")); + assert!(output.contains("# TYPE 
telemt_relay_pressure_evict_total counter")); + assert!(output.contains("# TYPE telemt_relay_protocol_desync_close_total counter")); assert!(output.contains("# TYPE telemt_me_writer_removed_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_attempt_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_success_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_timeout_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_escalation_total counter")); - assert!(output.contains("# TYPE telemt_me_writer_teardown_noop_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_teardown_duration_seconds histogram" - )); - assert!(output.contains( - "# TYPE telemt_me_writer_cleanup_side_effect_failures_total counter" - )); - assert!(output.contains("# TYPE telemt_me_writer_close_signal_drop_total counter")); - assert!(output.contains( - "# TYPE telemt_me_writer_close_signal_channel_full_total counter" - )); - assert!(output.contains( - "# TYPE telemt_me_draining_writers_reap_progress_total counter" - )); - assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_total counter")); - assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_writer_total counter")); assert!(output.contains( "# TYPE telemt_me_writer_removed_unexpected_minus_restored_total gauge" )); diff --git a/src/network/stun.rs b/src/network/stun.rs index c3a235f..6c6bd84 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -2,13 +2,20 @@ #![allow(dead_code)] use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::OnceLock; use tokio::net::{lookup_host, UdpSocket}; use tokio::time::{timeout, Duration, sleep}; +use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::network::dns_overrides::{resolve, split_host_port}; +fn stun_rng() -> &'static SecureRandom { + static STUN_RNG: OnceLock = OnceLock::new(); + 
STUN_RNG.get_or_init(SecureRandom::new) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IpFamily { V4, @@ -49,8 +56,6 @@ pub async fn stun_probe_family_with_bind( family: IpFamily, bind_ip: Option, ) -> Result> { - use rand::RngCore; - let bind_addr = match (family, bind_ip) { (IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0), (IpFamily::V6, Some(IpAddr::V6(ip))) => SocketAddr::new(IpAddr::V6(ip), 0), @@ -88,7 +93,7 @@ pub async fn stun_probe_family_with_bind( req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie - rand::rng().fill_bytes(&mut req[8..20]); // transaction ID + stun_rng().fill(&mut req[8..20]); // transaction ID let mut buf = [0u8; 256]; let mut attempt = 0; diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 9e79206..8130add 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -152,11 +152,29 @@ pub const TLS_RECORD_CHANGE_CIPHER: u8 = 0x14; pub const TLS_RECORD_APPLICATION: u8 = 0x17; /// TLS record type: Alert pub const TLS_RECORD_ALERT: u8 = 0x15; -/// Maximum TLS record size -pub const MAX_TLS_RECORD_SIZE: usize = 16384; -/// Maximum TLS chunk size (with overhead) -/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext -pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; +/// Maximum TLS plaintext record payload size. +/// RFC 8446 §5.1: "The length MUST NOT exceed 2^14 bytes." +/// Use this for validating incoming unencrypted records +/// (ClientHello, ChangeCipherSpec, unprotected Handshake messages). +pub const MAX_TLS_PLAINTEXT_SIZE: usize = 16_384; + +/// Structural minimum for a valid TLS 1.3 ClientHello with SNI. +/// Derived from RFC 8446 §4.1.2 field layout + Appendix D.4 compat mode. 
+/// Deliberately conservative (below any real client) to avoid false +/// positives on legitimate connections with compact extension sets. +pub const MIN_TLS_CLIENT_HELLO_SIZE: usize = 100; + +/// Maximum TLS ciphertext record payload size. +/// RFC 8446 §5.2: "The length MUST NOT exceed 2^14 + 256 bytes." +/// The +256 accounts for maximum AEAD expansion overhead. +/// Use this for validating or sizing buffers for encrypted records. +pub const MAX_TLS_CIPHERTEXT_SIZE: usize = 16_384 + 256; + +#[deprecated(note = "use MAX_TLS_PLAINTEXT_SIZE")] +pub const MAX_TLS_RECORD_SIZE: usize = MAX_TLS_PLAINTEXT_SIZE; + +#[deprecated(note = "use MAX_TLS_CIPHERTEXT_SIZE")] +pub const MAX_TLS_CHUNK_SIZE: usize = MAX_TLS_CIPHERTEXT_SIZE; /// Secure Intermediate payload is expected to be 4-byte aligned. pub fn is_valid_secure_payload_len(data_len: usize) -> bool { @@ -319,6 +337,10 @@ pub mod rpc_flags { pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; + + #[cfg(test)] + #[path = "tests/tls_size_constants_security_tests.rs"] + mod tls_size_constants_security_tests; #[cfg(test)] mod tests { diff --git a/src/protocol/tests/tls_adversarial_tests.rs b/src/protocol/tests/tls_adversarial_tests.rs new file mode 100644 index 0000000..b8df41a --- /dev/null +++ b/src/protocol/tests/tls_adversarial_tests.rs @@ -0,0 +1,351 @@ +use super::*; +use std::time::Instant; +use crate::crypto::sha256_hmac; + +/// Helper to create a byte vector of specific length. +fn make_garbage(len: usize) -> Vec { + vec![0x42u8; len] +} + +/// Helper to create a valid-looking HMAC digest for test. 
+fn make_digest(secret: &[u8], msg: &[u8], ts: u32) -> [u8; 32] { + let mut hmac = sha256_hmac(secret, msg); + let ts_bytes = ts.to_le_bytes(); + for i in 0..4 { + hmac[28 + i] ^= ts_bytes[i]; + } + hmac +} + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let digest = make_digest(secret, &handshake, timestamp); + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + +// ------------------------------------------------------------------ +// Truncated Packet Tests (OWASP ASVS 5.1.4, 5.1.5) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_truncated_10_bytes_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + let truncated = make_garbage(10); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn validate_tls_handshake_truncated_at_digest_start_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS = 11. 11 bytes should be rejected. 
+ let truncated = make_garbage(TLS_DIGEST_POS); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn validate_tls_handshake_truncated_inside_digest_rejected() { + let secrets = vec![("user".to_string(), b"secret".to_vec())]; + // TLS_DIGEST_POS + 16 (half digest) + let truncated = make_garbage(TLS_DIGEST_POS + 16); + assert!(validate_tls_handshake(&truncated, &secrets, true).is_none()); +} + +#[test] +fn extract_sni_truncated_at_record_header_rejected() { + let truncated = make_garbage(3); + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +#[test] +fn extract_sni_truncated_at_handshake_header_rejected() { + let mut truncated = vec![TLS_RECORD_HANDSHAKE, 0x03, 0x03, 0x00, 0x05]; + truncated.extend_from_slice(&[0x01, 0x00]); // ClientHello type but truncated length + assert!(extract_sni_from_client_hello(&truncated).is_none()); +} + +// ------------------------------------------------------------------ +// Malformed Extension Parsing Tests +// ------------------------------------------------------------------ + +#[test] +fn extract_sni_with_overlapping_extension_lengths_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x20]); // Total Extensions length: 32 + + // Extension 1: SNI (type 0) + h.extend_from_slice(&[0x00, 0x00]); + h.extend_from_slice(&[0x00, 0x40]); // Claimed len: 64 (OVERFLOWS total extensions len 32) + h.extend_from_slice(&[0u8; 64]); + + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn 
extract_sni_with_infinite_loop_potential_extension_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // Handshake type: ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); // Length: 92 + h.extend_from_slice(&[0x03, 0x03]); // Version + h.extend_from_slice(&[0u8; 32]); // Random + h.push(0); // Session ID length: 0 + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // Cipher suites + h.extend_from_slice(&[0x01, 0x00]); // Compression + + // Extensions start + h.extend_from_slice(&[0x00, 0x10]); // Total Extensions length: 16 + + // Extension: zero length but claims more? + // If our parser didn't advance, it might loop. + // Telemt uses `pos += 4 + elen;` so it always advances. + h.extend_from_slice(&[0x12, 0x34]); // Unknown type + h.extend_from_slice(&[0x00, 0x00]); // Length 0 + + // Fill the rest with garbage + h.extend_from_slice(&[0x42; 12]); + + // We expect it to finish without SNI found + assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_sni_with_invalid_hostname_rejected() { + let host = b"invalid_host!%^"; + let mut sni = Vec::new(); + sni.extend_from_slice(&((host.len() + 3) as u16).to_be_bytes()); + sni.push(0); + sni.extend_from_slice(&(host.len() as u16).to_be_bytes()); + sni.extend_from_slice(host); + + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x60]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0x00, 0x5C]); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + h.extend_from_slice(&[0x01, 0x00]); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni); + + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + h.extend_from_slice(&ext); + + assert!(extract_sni_from_client_hello(&h).is_none(), "Invalid SNI hostname must be rejected"); +} + 
+// ------------------------------------------------------------------ +// Timing Neutrality Tests (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[test] +fn validate_tls_handshake_timing_neutrality() { + let secret = b"timing_test_secret_32_bytes_long_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let mut base = vec![0x42u8; 100]; + base[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + const ITER: usize = 600; + const ROUNDS: usize = 7; + + let mut per_round_avg_diff_ns = Vec::with_capacity(ROUNDS); + + for round in 0..ROUNDS { + let mut success_h = base.clone(); + let mut fail_h = base.clone(); + + let start_success = Instant::now(); + for _ in 0..ITER { + let digest = make_digest(secret, &success_h, 0); + success_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&success_h, &secrets, true, 0); + } + let success_elapsed = start_success.elapsed(); + + let start_fail = Instant::now(); + for i in 0..ITER { + let mut digest = make_digest(secret, &fail_h, 0); + let flip_idx = (i + round) % (TLS_DIGEST_LEN - 4); + digest[flip_idx] ^= 0xFF; + fail_h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + let _ = validate_tls_handshake_at_time(&fail_h, &secrets, true, 0); + } + let fail_elapsed = start_fail.elapsed(); + + let diff = if success_elapsed > fail_elapsed { + success_elapsed - fail_elapsed + } else { + fail_elapsed - success_elapsed + }; + per_round_avg_diff_ns.push(diff.as_nanos() as f64 / ITER as f64); + } + + per_round_avg_diff_ns.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let median_avg_diff_ns = per_round_avg_diff_ns[ROUNDS / 2]; + + // Keep this as a coarse side-channel guard only; noisy shared CI hosts can + // introduce microsecond-level jitter that should not fail deterministic suites. 
+ assert!( + median_avg_diff_ns < 50_000.0, + "Median timing delta too large: {} ns/iter", + median_avg_diff_ns + ); +} + +// ------------------------------------------------------------------ +// Adversarial Fingerprinting / Active Probing Tests +// ------------------------------------------------------------------ + +#[test] +fn is_tls_handshake_robustness_against_probing() { + // Valid TLS 1.0 ClientHello + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + // Valid TLS 1.2/1.3 ClientHello (Legacy Record Layer) + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); + + // Invalid record type but matching version + assert!(!is_tls_handshake(&[0x17, 0x03, 0x03])); + // Plaintext HTTP request + assert!(!is_tls_handshake(b"GET / HTTP/1.1")); + // Short garbage + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn validate_tls_handshake_at_time_strict_boundary() { + let secret = b"strict_boundary_secret_32_bytes_"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_000_000_000; + + // Boundary: exactly TIME_SKEW_MAX (120s past) + let ts_past = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake_with_session_id(secret, ts_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h, &secrets, false, now).is_some()); + + // Boundary + 1s: should be rejected + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h2 = make_valid_tls_handshake_with_session_id(secret, ts_too_past, &[0x42; 32]); + assert!(validate_tls_handshake_at_time(&h2, &secrets, false, now).is_none()); +} + +#[test] +fn extract_sni_with_duplicate_extensions_rejected() { + // Construct a ClientHello with TWO SNI extensions + let host1 = b"first.com"; + let mut sni1 = Vec::new(); + sni1.extend_from_slice(&((host1.len() + 3) as u16).to_be_bytes()); + sni1.push(0); + sni1.extend_from_slice(&(host1.len() as u16).to_be_bytes()); + sni1.extend_from_slice(host1); + + let host2 = b"second.com"; + let mut sni2 = Vec::new(); + sni2.extend_from_slice(&((host2.len() 
+ 3) as u16).to_be_bytes()); + sni2.push(0); + sni2.extend_from_slice(&(host2.len() as u16).to_be_bytes()); + sni2.extend_from_slice(host2); + + let mut ext = Vec::new(); + // Ext 1: SNI + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni1.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni1); + // Ext 2: SNI again + ext.extend_from_slice(&0x0000u16.to_be_bytes()); + ext.extend_from_slice(&(sni2.len() as u16).to_be_bytes()); + ext.extend_from_slice(&sni2); + + let mut body = Vec::new(); + body.extend_from_slice(&[0x03, 0x03]); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); + body.extend_from_slice(&[0x01, 0x00]); + body.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut h = Vec::new(); + h.push(0x16); + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + h.extend_from_slice(&handshake); + + // Duplicate SNI extensions are ambiguous and must fail closed. 
+ assert!(extract_sni_from_client_hello(&h).is_none()); +} + +#[test] +fn extract_alpn_with_malformed_list_rejected() { + let mut alpn_payload = Vec::new(); + alpn_payload.extend_from_slice(&0x0005u16.to_be_bytes()); // Total len 5 + alpn_payload.push(10); // Labeled len 10 (OVERFLOWS total 5) + alpn_payload.extend_from_slice(b"h2"); + + let mut ext = Vec::new(); + ext.extend_from_slice(&0x0010u16.to_be_bytes()); // Type: ALPN (16) + ext.extend_from_slice(&(alpn_payload.len() as u16).to_be_bytes()); + ext.extend_from_slice(&alpn_payload); + + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x40, 0x01, 0x00, 0x00, 0x3C, 0x03, 0x03]; + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + h.extend_from_slice(&(ext.len() as u16).to_be_bytes()); + h.extend_from_slice(&ext); + + let res = extract_alpn_from_client_hello(&h); + assert!(res.is_empty(), "Malformed ALPN list must return empty or fail"); +} + +#[test] +fn extract_sni_with_huge_extension_header_rejected() { + let mut h = vec![0x16, 0x03, 0x03, 0x00, 0x00]; // Record header + h.push(0x01); // ClientHello + h.extend_from_slice(&[0x00, 0xFF, 0xFF]); // Huge length (65535) - overflows record + h.extend_from_slice(&[0x03, 0x03]); + h.extend_from_slice(&[0u8; 32]); + h.push(0); + h.extend_from_slice(&[0x00, 0x02, 0x13, 0x01, 0x01, 0x00]); + + // Extensions start + h.extend_from_slice(&[0xFF, 0xFF]); // Total extensions: 65535 (OVERFLOWS everything) + + assert!(extract_sni_from_client_hello(&h).is_none()); +} diff --git a/src/protocol/tests/tls_fuzz_security_tests.rs b/src/protocol/tests/tls_fuzz_security_tests.rs new file mode 100644 index 0000000..32d8efe --- /dev/null +++ b/src/protocol/tests/tls_fuzz_security_tests.rs @@ -0,0 +1,195 @@ +use super::*; +use crate::crypto::sha256_hmac; +use std::panic::catch_unwind; + +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = 
session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let mut digest = sha256_hmac(secret, &handshake); + let ts = timestamp.to_le_bytes(); + for idx in 0..4 { + digest[28 + idx] ^= ts[idx]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&digest); + handshake +} + +fn make_valid_client_hello_record(host: &str, alpn_protocols: &[&[u8]]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + 
ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn client_hello_fuzz_corpus_never_panics_or_accepts_corruption() { + let valid = make_valid_client_hello_record("example.com", &[b"h2", b"http/1.1"]); + assert_eq!(extract_sni_from_client_hello(&valid).as_deref(), Some("example.com")); + assert_eq!( + extract_alpn_from_client_hello(&valid), + vec![b"h2".to_vec(), b"http/1.1".to_vec()] + ); + assert!( + extract_sni_from_client_hello(&make_valid_client_hello_record("127.0.0.1", &[])).is_none(), + "literal IP hostnames must be rejected" + ); + + let mut corpus = vec![ + Vec::new(), + vec![0x16, 0x03, 0x03], + valid[..9].to_vec(), + valid[..valid.len() - 1].to_vec(), + ]; + + let mut wrong_type = valid.clone(); + wrong_type[0] = 0x15; + corpus.push(wrong_type); + + let mut wrong_handshake = valid.clone(); + wrong_handshake[5] = 0x02; + corpus.push(wrong_handshake); + + let mut wrong_length = valid.clone(); + wrong_length[3] ^= 0x7f; + corpus.push(wrong_length); + + for (idx, input) in corpus.iter().enumerate() { + assert!(catch_unwind(|| extract_sni_from_client_hello(input)).is_ok()); + assert!(catch_unwind(|| extract_alpn_from_client_hello(input)).is_ok()); + + if idx == 0 { + continue; + } + + assert!(extract_sni_from_client_hello(input).is_none(), "corpus item {idx} must fail closed for SNI"); + assert!(extract_alpn_from_client_hello(input).is_empty(), "corpus item {idx} must fail closed for ALPN"); + } +} + +#[test] +fn 
tls_handshake_fuzz_corpus_never_panics_and_rejects_digest_mutations() { + let secret = b"tls_fuzz_security_secret"; + let now: i64 = 1_700_000_000; + let base = make_valid_tls_handshake_with_session_id(secret, now as u32, &[0x42; 32]); + let secrets = vec![("fuzz-user".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake_at_time(&base, &secrets, false, now).is_some()); + + let mut corpus = Vec::new(); + + let mut truncated = base.clone(); + truncated.truncate(TLS_DIGEST_POS + 16); + corpus.push(truncated); + + let mut digest_flip = base.clone(); + digest_flip[TLS_DIGEST_POS + 7] ^= 0x80; + corpus.push(digest_flip); + + let mut session_id_len_overflow = base.clone(); + session_id_len_overflow[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 33; + corpus.push(session_id_len_overflow); + + let mut timestamp_far_past = base.clone(); + timestamp_far_past[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - i64::from(TIME_SKEW_MAX) - 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_past); + + let mut timestamp_far_future = base.clone(); + timestamp_far_future[TLS_DIGEST_POS + 28..TLS_DIGEST_POS + 32] + .copy_from_slice(&((now - TIME_SKEW_MIN + 1) as u32).to_le_bytes()); + corpus.push(timestamp_far_future); + + let mut seed = 0xA5A5_5A5A_F00D_BAAD_u64; + for _ in 0..32 { + let mut mutated = base.clone(); + for _ in 0..2 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = TLS_DIGEST_POS + (seed as usize % TLS_DIGEST_LEN); + mutated[idx] ^= ((seed >> 17) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, handshake) in corpus.iter().enumerate() { + let result = catch_unwind(|| validate_tls_handshake_at_time(handshake, &secrets, false, now)); + assert!(result.is_ok(), "corpus item {idx} must not panic"); + assert!(result.unwrap().is_none(), "corpus item {idx} must fail closed"); + } +} + +#[test] +fn tls_boot_time_acceptance_is_capped_by_replay_window() { + let secret = b"tls_boot_time_cap_secret"; 
+ let secrets = vec![("boot-user".to_string(), secret.to_vec())]; + let boot_ts = 1u32; + let handshake = make_valid_tls_handshake_with_session_id(secret, boot_ts, &[0x42; 32]); + + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 300).is_some(), + "boot-time timestamp should be accepted while replay window permits it" + ); + assert!( + validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 0).is_none(), + "boot-time timestamp must be rejected when replay window disables the bypass" + ); +} diff --git a/src/protocol/tests/tls_security_tests.rs b/src/protocol/tests/tls_security_tests.rs new file mode 100644 index 0000000..a6e7b2b --- /dev/null +++ b/src/protocol/tests/tls_security_tests.rs @@ -0,0 +1,2355 @@ +use super::*; +use crate::crypto::sha256_hmac; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource}; +use std::time::SystemTime; + +/// Build a TLS-handshake-like buffer that contains a valid HMAC digest +/// for the given `secret` and `timestamp`. +/// +/// Layout (bytes): +/// [0..TLS_DIGEST_POS] : fixed filler (0x42) +/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le] +/// [TLS_DIGEST_POS+32] : session_id_len = 32 +/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42) +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + assert!(session_id_len <= u8::MAX as usize); + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); + // Zero the digest slot before computing HMAC (mirrors what validate does). 
+ handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + + // digest = HMAC such that XOR with stored digest yields [0..0, timestamp_le]. + // bytes 0-27 of digest == computed[0..28] -> xored[..28] == 0 + // bytes 28-31 of digest == computed[28..32] XOR timestamp_le + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + +// ------------------------------------------------------------------ +// Happy-path sanity +// ------------------------------------------------------------------ + +#[test] +fn valid_handshake_with_correct_secret_accepted() { + let secret = b"correct_horse_battery_staple_32b"; + // timestamp = 0 triggers is_boot_time path, accepted without wall-clock check. + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("alice".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "Valid handshake must be accepted"); + assert_eq!(result.unwrap().user, "alice"); +} + +#[test] +fn deterministic_external_vector_validates_without_helper() { + // Deterministic vector generated by an external Python stdlib HMAC script, + // not by this test module helper. This catches mirrored helper mistakes. 
+ let secret = hex::decode("00112233445566778899aabbccddeeff").unwrap(); + let handshake = hex::decode( + "4242424242424242424242a93225d1d6b46260bc9ce0cc48c7487d2b1ca5afa7ae9fc6609d9e60a3ca842b204242424242424242424242424242424242424242424242424242424242424242", + ) + .unwrap(); + + let secrets = vec![("vector_user".to_string(), secret)]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + assert_eq!(result.user, "vector_user"); + assert_eq!(result.timestamp, 0x01020304); +} + +#[test] +fn valid_handshake_timestamp_extracted_correctly() { + let secret = b"ts_extraction_test"; + let ts: u32 = 0xDEAD_BEEF; + let handshake = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().timestamp, ts); +} + +// ------------------------------------------------------------------ +// HMAC bit-flip rejection - adversarial HMAC forgery attempts +// ------------------------------------------------------------------ + +/// Flip every single bit across the 28-byte HMAC check window one at a +/// time. Each flip must cause rejection. This is the primary guard +/// against a censor gradually narrowing down a valid HMAC via partial +/// matches (which would be exploitable with a variable-time comparison). +#[test] +fn hmac_single_bit_flip_anywhere_in_check_window_rejected() { + let secret = b"flip_test_secret"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // First ensure the unmodified handshake is accepted. 
+ assert!( + validate_tls_handshake(&base, &secrets, true).is_some(), + "Baseline handshake must be accepted before flip tests" + ); + + for byte_pos in 0..28usize { + for bit in 0u8..8 { + let mut h = base.clone(); + h[TLS_DIGEST_POS + byte_pos] ^= 1 << bit; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip of bit {bit} in HMAC byte {byte_pos} must be rejected" + ); + } + } +} + +/// XOR entire check window (bytes 0-27) with 0xFF - must still fail. +#[test] +fn hmac_full_window_corruption_rejected() { + let secret = b"full_window_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + for i in 0..28 { + h[TLS_DIGEST_POS + i] ^= 0xFF; + } + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Byte 27 is the last byte in the checked window. A non-constant-time +/// `all(|b| b == 0)` that short-circuits on byte 0 would never even reach +/// byte 27, making this an effective "did the fix actually run to the end" +/// sentinel: if this passes but the earlier byte-0 test fails, the check +/// window is not being evaluated end-to-end. +#[test] +fn hmac_last_byte_of_check_window_enforced() { + let secret = b"last_byte_sentinel"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt only byte 27. + h[TLS_DIGEST_POS + 27] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Corruption at byte 27 (end of HMAC window) must cause rejection" + ); +} + +// ------------------------------------------------------------------ +// User enumeration / multi-user ordering +// ------------------------------------------------------------------ + +#[test] +fn wrong_user_secret_rejected_even_with_valid_structure() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + // Only user_a is configured. 
+ let secrets = vec![("user_a".to_string(), secret_a.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "Handshake for user_b must fail when only user_a is configured" + ); +} + +#[test] +fn second_user_in_list_found_when_first_does_not_match() { + let secret_a = b"secret_alpha"; + let secret_b = b"secret_beta"; + let handshake = make_valid_tls_handshake(secret_b, 0); + let secrets = vec![ + ("user_a".to_string(), secret_a.to_vec()), + ("user_b".to_string(), secret_b.to_vec()), + ]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "user_b must be found even though user_a comes first"); + assert_eq!(result.unwrap().user, "user_b"); +} + +#[test] +fn duplicate_secret_keeps_first_user_identity() { + // If multiple entries share the same secret, the selected identity must + // stay stable and deterministic (first entry wins). + let shared = b"same_secret_for_two_users"; + let handshake = make_valid_tls_handshake(shared, 0); + let secrets = vec![ + ("first_user".to_string(), shared.to_vec()), + ("second_user".to_string(), shared.to_vec()), + ]; + + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some()); + assert_eq!(result.unwrap().user, "first_user"); +} + +#[test] +fn no_user_matches_returns_none() { + let secret_a = b"aaa"; + let secret_b = b"bbb"; + let secret_c = b"ccc"; + let handshake = make_valid_tls_handshake(b"unknown_secret", 0); + let secrets = vec![ + ("a".to_string(), secret_a.to_vec()), + ("b".to_string(), secret_b.to_vec()), + ("c".to_string(), secret_c.to_vec()), + ]; + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +#[test] +fn empty_secrets_list_rejects_everything() { + let secret = b"test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets: Vec<(String, Vec)> = Vec::new(); + assert!(validate_tls_handshake(&handshake, &secrets, true).is_none()); +} + +// 
------------------------------------------------------------------ +// Timestamp / time-skew boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn timestamp_at_time_skew_boundaries_accepted() { + let secret = b"skew_boundary_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = now - ts = TIME_SKEW_MIN = -1200 + // -> ts = now - TIME_SKEW_MIN = now + 1200 (20 min in the future). + let ts_at_future_limit = (now - TIME_SKEW_MIN) as u32; + let h = make_valid_tls_handshake(secret, ts_at_future_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed future (time_diff = TIME_SKEW_MIN) must be accepted" + ); + + // time_diff = now - ts = TIME_SKEW_MAX = 600 + // -> ts = now - TIME_SKEW_MAX = now - 600 (10 min in the past). + let ts_at_past_limit = (now - TIME_SKEW_MAX) as u32; + let h = make_valid_tls_handshake(secret, ts_at_past_limit); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_some(), + "Timestamp at max-allowed past (time_diff = TIME_SKEW_MAX) must be accepted" + ); +} + +#[test] +fn timestamp_one_second_outside_skew_window_rejected() { + let secret = b"skew_outside_test_secret"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // time_diff = TIME_SKEW_MAX + 1 = 601 (one second too far in the past) + // -> ts = now - (TIME_SKEW_MAX + 1) = now - 601 + let ts_too_past = (now - TIME_SKEW_MAX - 1) as u32; + let h = make_valid_tls_handshake(secret, ts_too_past); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the past must be rejected" + ); + + // time_diff = TIME_SKEW_MIN - 1 = -1201 (one second too far in the future) + // -> ts = now - (TIME_SKEW_MIN - 1) = now + 1201 + let ts_too_future = (now - TIME_SKEW_MIN + 1) as u32; + let h = 
make_valid_tls_handshake(secret, ts_too_future); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "Timestamp one second too far in the future must be rejected" + ); +} + +#[test] +fn ignore_time_skew_accepts_far_future_timestamp() { + let secret = b"ignore_skew_test"; + let now: i64 = 1_700_000_000; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // 1 hour in the future - outside TIME_SKEW_MAX but should pass with flag. + let future_ts = (now + 3600) as u32; + let h = make_valid_tls_handshake(secret, future_ts); + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, now).is_some(), + "ignore_time_skew=true must override window rejection" + ); + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "ignore_time_skew=false must still reject far-future timestamp" + ); +} + +#[test] +fn boot_time_timestamp_accepted_without_ignore_flag() { + // Timestamps below the boot-time threshold are treated as client uptime, + // not real wall-clock time. The proxy allows them regardless of skew. + let secret = b"boot_time_test"; + // Keep this safely below compatibility cap to assert bypass behavior. + let boot_ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); + let handshake = make_valid_tls_handshake(secret, boot_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&handshake, &secrets, false).is_some(), + "Boot-time timestamp must be accepted even with ignore_time_skew=false" + ); +} + +// ------------------------------------------------------------------ +// Structural / length boundary attacks +// ------------------------------------------------------------------ + +#[test] +fn too_short_handshake_rejected_without_panic() { + let secrets = vec![("u".to_string(), b"s".to_vec())]; + // Exactly one byte short of the minimum required length. 
+ let h = vec![0u8; TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); + + // Empty buffer. + assert!(validate_tls_handshake(&[], &secrets, true).is_none()); +} + +#[test] +fn all_prefix_lengths_below_minimum_rejected_without_panic() { + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + + for len in 0..min_len { + let h = vec![0u8; len]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "prefix length {len} below minimum must be rejected" + ); + } +} + +#[test] +fn claimed_session_id_overflows_buffer_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0u8; min_len]; + // Claim session_id is 33 bytes - one more than the buffer holds. + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = (session_id_len + 1) as u8; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn max_session_id_len_255_does_not_panic() { + // session_id_len = 255 with a buffer that is far too small for it. 
+ let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + 32; + let mut h = vec![0u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 255; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn one_byte_session_id_validates_and_is_preserved() { + let secret = b"sid_len_1_test"; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &[0xAB]); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&handshake, &secrets, true) + .expect("one-byte session_id handshake must validate"); + assert_eq!(result.session_id, vec![0xAB]); +} + +#[test] +fn max_session_id_len_255_with_valid_digest_is_rejected_by_rfc_cap() { + let secret = b"sid_len_255_test"; + let session_id = vec![0xCCu8; 255]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected even with valid digest" + ); +} + +// ------------------------------------------------------------------ +// Adversarial digest values +// ------------------------------------------------------------------ + +#[test] +fn all_zeros_digest_rejected() { + // An all-zeros digest would only pass if HMAC(secret, msg) happens to + // have its first 28 bytes all zero, which is computationally infeasible. 
+ let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +#[test] +fn all_ones_digest_rejected() { + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0xFF); + let secrets = vec![("u".to_string(), b"test_secret".to_vec())]; + assert!(validate_tls_handshake(&h, &secrets, true).is_none()); +} + +/// Simulate a censor that sends 200 crafted packets with random digests. +/// Every single one must be rejected; no random digest should accidentally +/// pass (probability 2^{-224} per attempt; negligible for 200 trials). +#[test] +fn censor_probe_random_digests_all_rejected() { + use crate::crypto::SecureRandom; + let secret = b"production_like_secret_value_xyz"; + let session_id_len: usize = 32; + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let rng = SecureRandom::new(); + + for attempt in 0..200 { + let mut h = vec![0x42u8; min_len]; + h[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let rand_digest = rng.bytes(TLS_DIGEST_LEN); + h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&rand_digest); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Random digest at attempt {attempt} must not match" + ); + } +} + +/// The check window is bytes 0-27 of the XOR result. 
Bytes 28-31 encode +/// the timestamp and must NOT affect whether the HMAC portion validates - +/// only the timestamp range check uses them. Build a valid handshake with +/// timestamp = 0 (boot-time), flip each of bytes 28-31 with ignore_time_skew +/// enabled, and verify the HMAC portion still passes (the timestamp changes +/// but the proxy still accepts the connection under ignore_time_skew). +#[test] +fn timestamp_bytes_28_31_do_not_affect_hmac_window() { + let secret = b"window_boundary_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // Baseline must pass. + assert!(validate_tls_handshake(&base, &secrets, true).is_some()); + + // Flip each of the timestamp bytes; with ignore_time_skew the + // modified timestamps (small absolute values) still pass boot-time check. + for i in 28..32usize { + let mut h = base.clone(); + h[TLS_DIGEST_POS + i] ^= 0xFF; + // The new timestamp is non-zero but potentially still < boot threshold; + // use ignore_time_skew=true so the test is HMAC-only. + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "Flipping byte {i} (timestamp region) must not invalidate HMAC window" + ); + } +} + +// ------------------------------------------------------------------ +// session_id preservation +// ------------------------------------------------------------------ + +#[test] +fn session_id_is_preserved_verbatim_in_validation_result() { + // If session_id extraction is ever broken (wrong offset, wrong length, + // off-by-one), this test will catch it before it silently corrupts the + // ServerHello that echoes the session_id back to the client. 
+ let secret = b"session_id_preservation_test"; + let handshake = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true).unwrap(); + + let sid_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; + let sid_len = handshake[sid_len_pos] as usize; + let expected = &handshake[sid_len_pos + 1..sid_len_pos + 1 + sid_len]; + + assert_eq!( + result.session_id, expected, + "session_id in TlsValidation must be the verbatim bytes from the handshake" + ); +} + +// ------------------------------------------------------------------ +// Clock decoupling - ignore_time_skew must not consult the system clock +// ------------------------------------------------------------------ + +/// When `ignore_time_skew = true`, a valid HMAC must be accepted even if +/// `now = 0` (the sentinel used when the clock is not needed). A broken +/// system clock cannot silently deny service when the admin has explicitly +/// disabled timestamp checking. +#[test] +fn ignore_time_skew_accepts_valid_hmac_with_now_zero() { + let secret = b"clock_decoupling_test"; + // Use a realistic Unix timestamp that would be far outside the window + // if compared against now=0 (time_diff would be ~-1_700_000_000). + let realistic_ts: u32 = 1_700_000_000; + let h = make_valid_tls_handshake(secret, realistic_ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_some(), + "ignore_time_skew=true must accept a valid HMAC regardless of `now`" + ); + + // Confirm that the same handshake IS rejected when the window is enforced + // and now=0 (time_diff very negative -> outside window). This distinguishes + // "clock decoupling" from "always accept". 
+ assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), + "ignore_time_skew=false with now=0 must still reject out-of-window timestamps" + ); +} + +/// An HMAC-invalid handshake must be rejected even when ignore_time_skew=true +/// and now=0. Verifies that the clock-decoupling fix did not weaken HMAC +/// enforcement in the ignore_time_skew path. +#[test] +fn ignore_time_skew_with_now_zero_still_rejects_bad_hmac() { + let secret = b"clock_no_backdoor_test"; + let mut h = make_valid_tls_handshake(secret, 1_700_000_000); + let secrets = vec![("u".to_string(), secret.to_vec())]; + // Corrupt the HMAC check window. + h[TLS_DIGEST_POS] ^= 0xFF; + assert!( + validate_tls_handshake_at_time(&h, &secrets, true, 0).is_none(), + "Broken HMAC must be rejected even with ignore_time_skew=true and now=0" + ); +} + +#[test] +fn system_time_before_unix_epoch_is_rejected_without_panic() { + let before_epoch = UNIX_EPOCH + .checked_sub(std::time::Duration::from_secs(1)) + .expect("UNIX_EPOCH minus one second must be representable"); + assert!(system_time_to_unix_secs(before_epoch).is_none()); +} + +/// `i64::MAX` is 9_223_372_036_854_775_807 seconds (~292 billion years CE). +/// Any `SystemTime` whose duration since epoch exceeds `i64::MAX` seconds +/// must return `None` rather than silently wrapping to a large negative +/// timestamp that would corrupt every subsequent time-skew comparison. +#[test] +fn system_time_far_future_overflowing_i64_returns_none() { + // i64::MAX + 1 seconds past epoch overflows i64 when cast naively with `as`. 
+ let overflow_secs = u64::try_from(i64::MAX).unwrap() + 1; + if let Some(far_future) = + UNIX_EPOCH.checked_add(std::time::Duration::from_secs(overflow_secs)) + { + assert!( + system_time_to_unix_secs(far_future).is_none(), + "Seconds > i64::MAX must return None, not a wrapped negative timestamp" + ); + } + // If the platform cannot represent this SystemTime, the test is vacuously + // satisfied: `checked_add` returning None means the platform already rejects it. +} + +// ------------------------------------------------------------------ +// Message canonicalization — HMAC covers every byte of the handshake +// ------------------------------------------------------------------ + +/// Every byte before TLS_DIGEST_POS is part of the HMAC input (because msg +/// = full handshake with only the digest slot zeroed). An attacker cannot +/// replay a valid handshake with a modified ClientHello header while keeping +/// the stored digest; each such modification produces a different HMAC. +#[test] +fn pre_digest_bytes_are_hmac_covered() { + // TLS_DIGEST_POS = 11, so 11 bytes precede the digest. + let secret = b"pre_digest_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for byte_pos in 0..TLS_DIGEST_POS { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in pre-digest byte {byte_pos} must cause HMAC check failure" + ); + } +} + +/// session_id bytes follow the digest in the buffer and are also part of the +/// HMAC input. Flipping any of them invalidates the stored digest, preventing +/// a censor from capturing a valid session_id and replaying it with a different +/// one while keeping the rest of the packet intact. 
+#[test] +fn session_id_bytes_are_hmac_covered() { + let secret = b"session_id_coverage_test"; + let base = make_valid_tls_handshake(secret, 0); // session_id_len = 32 + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + for byte_pos in sid_start..base.len() { + let mut h = base.clone(); + h[byte_pos] ^= 0x01; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Flip in session_id byte at offset {byte_pos} must cause HMAC check failure" + ); + } +} + +/// Appending even one byte to a valid handshake changes the HMAC input (msg +/// includes all bytes) and therefore invalidates the stored digest. This +/// prevents a length-extension-style modification of the payload. +#[test] +fn appended_trailing_byte_causes_rejection() { + let secret = b"trailing_byte_test"; + let mut h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!(validate_tls_handshake(&h, &secrets, true).is_some(), "baseline"); + + h.push(0x00); + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "Appending a trailing byte to a valid handshake must invalidate the HMAC" + ); +} + +// ------------------------------------------------------------------ +// Zero-length session_id (structural edge case) +// ------------------------------------------------------------------ + +/// session_id_len = 0 is legal in the TLS spec. The validator must accept a +/// valid handshake with an empty session_id and return an empty session_id +/// slice without panicking or accessing out-of-bounds memory. 
+#[test] +fn zero_length_session_id_accepted() { + let secret = b"zero_sid_test"; + // Buffer: pre-digest | digest | session_id_len=0 (no session_id bytes follow) + let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let mut handshake = vec![0x42u8; len]; + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 0; // session_id_len = 0 + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + // timestamp = 0 → ts XOR bytes are all zero → digest = computed unchanged. + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .copy_from_slice(&computed); + + let secrets = vec![("u".to_string(), secret.to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_some(), "zero-length session_id must be accepted"); + assert!( + result.unwrap().session_id.is_empty(), + "session_id field must be empty when session_id_len = 0" + ); +} + +// ------------------------------------------------------------------ +// Boot-time threshold — exact boundary precision +// ------------------------------------------------------------------ + +/// timestamp = BOOT_TIME_COMPAT_MAX_SECS - 1 is the last value inside +/// the runtime boot-time compatibility window. +/// is_boot_time = true → skew check is skipped entirely → accepted even +/// when `now` is far from the timestamp. +#[test] +fn timestamp_one_below_boot_threshold_bypasses_skew_check() { + let secret = b"boot_last_value_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS - 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = 0 → time_diff would be -86_399_999, way outside [-1200, 600]. + // Boot-time bypass must prevent the skew check from running. 
+ assert!( + validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), + "ts=BOOT_TIME_COMPAT_MAX_SECS-1 must bypass skew check regardless of now" + ); +} + +/// timestamp = BOOT_TIME_COMPAT_MAX_SECS is the first value outside the +/// runtime boot-time compatibility window. +/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: +/// once with now chosen so the skew passes (accepted) and once where it fails. +#[test] +fn timestamp_at_boot_threshold_triggers_skew_check() { + let secret = b"boot_exact_value_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + // now = ts + 50 → time_diff = 50, within [-1200, 600] → accepted. + let now_valid: i64 = ts as i64 + 50; + assert!( + validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + false, + now_valid, + BOOT_TIME_COMPAT_MAX_SECS, + ) + .is_some(), + "ts=BOOT_TIME_COMPAT_MAX_SECS within skew window must be accepted via skew check" + ); + + // now = -1 → time_diff = -1 - BOOT_TIME_COMPAT_MAX_SECS, deeply negative and + // far outside [TIME_SKEW_MIN, TIME_SKEW_MAX]. If boot-time bypass were wrongly + // applied this would pass. 
+ assert!( + validate_tls_handshake_at_time_with_boot_cap( + &h, + &secrets, + false, + -1, + BOOT_TIME_COMPAT_MAX_SECS, + ) + .is_none(), + "ts=BOOT_TIME_COMPAT_MAX_SECS far from now must be rejected — no boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_disables_boot_bypass_for_old_timestamps() { + let secret = b"boot_cap_disabled_test"; + let ts: u32 = 900; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_none(), + "timestamp above replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn replay_window_cap_still_allows_small_boot_timestamp() { + let secret = b"boot_cap_enabled_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS.saturating_sub(1); + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300); + assert!( + result.is_some(), + "timestamp below replay-window cap must retain boot-time compatibility" + ); +} + +#[test] +fn large_replay_window_is_hard_capped_for_boot_compatibility() { + let secret = b"boot_cap_hard_limit_test"; + let ts: u32 = BOOT_TIME_COMPAT_MAX_SECS + 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, u64::MAX); + assert!( + result.is_none(), + "very large replay window must not expand boot-time bypass beyond hard compatibility cap" + ); +} + +#[test] +fn ignore_time_skew_explicitly_decouples_from_boot_time_cap() { + let secret = b"ignore_skew_boot_cap_decouple_test"; + let ts: u32 = 1; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let cap_zero = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, 0); + let 
cap_nonzero = + validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, true, 0, BOOT_TIME_COMPAT_MAX_SECS); + + assert!(cap_zero.is_some(), "ignore_time_skew=true must accept valid HMAC"); + assert!( + cap_nonzero.is_some(), + "ignore_time_skew path must not depend on boot-time cap" + ); + + let a = cap_zero.unwrap(); + let b = cap_nonzero.unwrap(); + assert_eq!(a.user, b.user); + assert_eq!(a.timestamp, b.timestamp); +} + +#[test] +fn adversarial_small_boot_timestamp_matrix_rejected_when_boot_cap_forced_zero() { + let secret = b"boot_cap_zero_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for ts in 0u32..1024u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "boot cap=0 must reject timestamp {ts} when skew checks are active" + ); + } +} + +#[test] +fn light_fuzz_boot_cap_zero_rejects_small_timestamp_space() { + let secret = b"boot_cap_zero_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = (s as u32) % 2048; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "fuzzed boot-range timestamp {ts} must be rejected when cap=0" + ); + } +} + +#[test] +fn stress_boot_cap_zero_rejection_is_deterministic_under_high_iteration_count() { + let secret = b"boot_cap_zero_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for i in 0u32..20_000u32 { + let ts = i % 4096; + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0); + assert!( + result.is_none(), + "iteration 
{i}: timestamp {ts} must be rejected with cap=0" + ); + } +} + +#[test] +fn replay_window_one_allows_only_zero_timestamp_boot_bypass() { + let secret = b"replay_window_one_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + + assert!( + validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 1).is_some(), + "replay_window=1 must allow timestamp 0 via boot-time compatibility" + ); + assert!( + validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 1).is_none(), + "replay_window=1 must reject timestamp 1 on normal wall-clock systems" + ); +} + +#[test] +fn replay_window_two_allows_ts0_ts1_but_rejects_ts2() { + let secret = b"replay_window_two_boot_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let ts0 = make_valid_tls_handshake(secret, 0); + let ts1 = make_valid_tls_handshake(secret, 1); + let ts2 = make_valid_tls_handshake(secret, 2); + + assert!(validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 2).is_some()); + assert!(validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 2).is_some()); + assert!( + validate_tls_handshake_with_replay_window(&ts2, &secrets, false, 2).is_none(), + "timestamp equal to replay-window cap must not use boot-time bypass" + ); +} + +#[test] +fn adversarial_skew_boundary_matrix_accepts_only_inclusive_window_when_boot_disabled() { + let secret = b"skew_boundary_matrix_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + + for offset in -1500i64..=1500i64 { + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for test matrix"); + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&offset); + assert_eq!( + accepted, 
expected, + "offset {offset} must match inclusive skew window when boot bypass is disabled" + ); + } +} + +#[test] +fn light_fuzz_skew_window_rejects_outside_range_when_boot_disabled() { + let secret = b"skew_outside_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0x0123_4567_89AB_CDEF; + + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let magnitude = 1300i64 + ((s % 2000u64) as i64); + let sign = if (s & 1) == 0 { 1i64 } else { -1i64 }; + let offset = sign * magnitude; + let ts_i64 = now - offset; + let ts = u32::try_from(ts_i64).expect("timestamp must fit u32 for fuzz test"); + + let h = make_valid_tls_handshake(secret, ts); + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + assert!( + !accepted, + "offset {offset} must be rejected outside strict skew window" + ); + } +} + +#[test] +fn stress_boot_disabled_validation_matches_time_diff_oracle() { + let secret = b"boot_disabled_oracle_stress_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let now: i64 = 1_700_000_000; + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for _ in 0..25_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + let h = make_valid_tls_handshake(secret, ts); + + let accepted = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, now, 0) + .is_some(); + let time_diff = now - i64::from(ts); + let expected = (TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff); + assert_eq!( + accepted, expected, + "boot-disabled validation must match pure time-diff oracle" + ); + } +} + +#[test] +fn integration_large_user_list_with_boot_disabled_finds_only_matching_user() { + let now: i64 = 1_700_000_000; + let target_secret = b"target_user_secret"; + let target_ts = (now - 1) as u32; + let handshake = make_valid_tls_handshake(target_secret, target_ts); + + let mut secrets = Vec::new(); + for i in 0..512u32 { + 
secrets.push((format!("noise-{i}"), format!("noise-secret-{i}").into_bytes())); + } + secrets.push(("target-user".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake_at_time_with_boot_cap(&handshake, &secrets, false, now, 0) + .expect("matching user should validate within strict skew window"); + assert_eq!(result.user, "target-user"); +} + +#[test] +fn light_fuzz_ignore_time_skew_accepts_wide_timestamp_range_with_valid_hmac() { + let secret = b"ignore_skew_fuzz_accept_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + let mut s: u64 = 0xC0FF_EE11_2233_4455; + + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ts = s as u32; + + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_with_replay_window(&h, &secrets, true, 60); + assert!( + result.is_some(), + "ignore_time_skew=true must accept valid HMAC for arbitrary timestamp" + ); + } +} + +#[test] +fn light_fuzz_small_replay_window_rejects_far_timestamps_when_skew_enabled() { + let secret = b"replay_window_reject_fuzz_test"; + let secrets = vec![("u".to_string(), secret.to_vec())]; + + for ts in 300u32..=1323u32 { + let h = make_valid_tls_handshake(secret, ts); + let result = validate_tls_handshake_at_time_with_boot_cap(&h, &secrets, false, 0, 300); + assert!( + result.is_none(), + "with skew checks enabled and boot cap=300, timestamp >=300 at now=0 must be rejected" + ); + } +} + +// ------------------------------------------------------------------ +// Extreme timestamp values +// ------------------------------------------------------------------ + +/// u32::MAX is a valid timestamp value. When ignore_time_skew=true the HMAC +/// is the only gate, and a correctly constructed handshake must be accepted. 
+#[test] +fn u32_max_timestamp_accepted_with_ignore_time_skew() { + let secret = b"u32_max_ts_accept_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "u32::MAX timestamp must be accepted with ignore_time_skew=true"); + assert_eq!( + result.unwrap().timestamp, + u32::MAX, + "timestamp field must equal u32::MAX verbatim" + ); +} + +/// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now` +/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside +/// [-1200, 600] — so the handshake must be rejected without overflow. +#[test] +fn u32_max_timestamp_rejected_by_skew_enforcement() { + let secret = b"u32_max_ts_reject_test"; + let h = make_valid_tls_handshake(secret, u32::MAX); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let now: i64 = 1_700_000_000; + assert!( + validate_tls_handshake_at_time(&h, &secrets, false, now).is_none(), + "u32::MAX timestamp must be rejected by skew check with realistic now" + ); +} + +// ------------------------------------------------------------------ +// Validation result field correctness +// ------------------------------------------------------------------ + +/// result.digest must be the verbatim bytes stored in the handshake buffer, +/// not the freshly recomputed HMAC. Callers use this field directly when +/// constructing the ServerHello response digest. 
+#[test] +fn result_digest_field_is_verbatim_stored_digest() { + let secret = b"digest_field_verbatim_test"; + let ts: u32 = 0xCAFE_BABE; + let h = make_valid_tls_handshake(secret, ts); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&h, &secrets, true).unwrap(); + + let stored: [u8; TLS_DIGEST_LEN] = h[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] + .try_into() + .unwrap(); + assert_eq!( + result.digest, stored, + "result.digest must equal the stored bytes, not the computed HMAC" + ); +} + +// ------------------------------------------------------------------ +// Secret length edge cases +// ------------------------------------------------------------------ + +/// HMAC-SHA256 pads or hashes keys of any length; a single-byte key must work. +#[test] +fn single_byte_secret_works() { + let secret = b"x"; + let h = make_valid_tls_handshake(secret, 0); + let secrets = vec![("u".to_string(), secret.to_vec())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "single-byte secret must produce a valid and verifiable HMAC" + ); +} + +/// Keys longer than the HMAC block size (64 bytes for SHA-256) are hashed +/// before use. A 256-byte key must work without truncation or panic. +#[test] +fn very_long_secret_256_bytes_works() { + let secret = vec![0xABu8; 256]; + let h = make_valid_tls_handshake(&secret, 0); + let secrets = vec![("u".to_string(), secret.clone())]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_some(), + "256-byte secret must be accepted without truncation" + ); +} + +// ------------------------------------------------------------------ +// Determinism — same input must always produce same result +// ------------------------------------------------------------------ + +/// Calling validate twice on the same input must return identical results. +/// Non-determinism (e.g. 
from an accidentally global mutable state or a +/// shared nonce) would be a critical security defect in a proxy that rejects +/// censors by relying on stable authentication outcomes. +#[test] +fn validation_is_deterministic() { + let secret = b"determinism_test_key"; + let h = make_valid_tls_handshake(secret, 42); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let r1 = validate_tls_handshake(&h, &secrets, true).unwrap(); + let r2 = validate_tls_handshake(&h, &secrets, true).unwrap(); + + assert_eq!(r1.user, r2.user); + assert_eq!(r1.session_id, r2.session_id); + assert_eq!(r1.digest, r2.digest); + assert_eq!(r1.timestamp, r2.timestamp); +} + +// ------------------------------------------------------------------ +// Multi-user: scan-all correctness guarantees +// ------------------------------------------------------------------ + +/// The matching logic must scan through the entire secrets list. A user +/// at position 99 of 100 must be found; an implementation that stops early +/// on the first non-match would fail this test. +#[test] +fn last_user_in_large_list_is_found() { + let target_secret = b"needle_in_haystack"; + let h = make_valid_tls_handshake(target_secret, 0); + + let mut secrets: Vec<(String, Vec)> = (0..99) + .map(|i| (format!("decoy_{i}"), format!("wrong_{i}").into_bytes())) + .collect(); + secrets.push(("needle".to_string(), target_secret.to_vec())); + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some(), "100th user must be found"); + assert_eq!(result.unwrap().user, "needle"); +} + +/// When multiple users share the same secret the first occurrence must always +/// win. The scan-all loop must not replace first_match with a later one. 
+#[test] +fn first_matching_user_wins_over_later_duplicate_secret() { + let shared = b"duplicated_secret_key"; + let h = make_valid_tls_handshake(shared, 0); + + let secrets = vec![ + ("decoy_1".to_string(), b"wrong_1".to_vec()), + ("winner".to_string(), shared.to_vec()), // first match + ("decoy_2".to_string(), b"wrong_2".to_vec()), + ("loser".to_string(), shared.to_vec()), // second match — must not win + ("decoy_3".to_string(), b"wrong_3".to_vec()), + ]; + + let result = validate_tls_handshake(&h, &secrets, true); + assert!(result.is_some()); + assert_eq!( + result.unwrap().user, "winner", + "first matching user must be returned even when a later entry also matches" + ); +} + +// ------------------------------------------------------------------ +// Legacy tls.rs tests moved here +// ------------------------------------------------------------------ + +#[test] +fn test_is_tls_handshake() { + assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); + assert!(is_tls_handshake(&[0x16, 0x03, 0x03, 0x02, 0x00])); + assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); + assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); + assert!(!is_tls_handshake(&[0x16, 0x03])); +} + +#[test] +fn test_parse_tls_record_header() { + let header = [0x16, 0x03, 0x01, 0x02, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_HANDSHAKE); + assert_eq!(result.1, 512); + + let header = [0x17, 0x03, 0x03, 0x40, 0x00]; + let result = parse_tls_record_header(&header).unwrap(); + assert_eq!(result.0, TLS_RECORD_APPLICATION); + assert_eq!(usize::from(result.1), MAX_TLS_PLAINTEXT_SIZE); +} + +#[test] +fn parse_tls_record_header_rejects_invalid_versions() { + let invalid = [ + [0x16, 0x03, 0x00, 0x00, 0x10], + [0x16, 0x02, 0x00, 0x00, 0x10], + [0x16, 0x03, 0x02, 0x00, 0x10], + [0x16, 0x04, 0x00, 0x00, 0x10], + ]; + for header in invalid { + assert!( + 
parse_tls_record_header(&header).is_none(), + "invalid TLS record version {:?} must be rejected", + [header[1], header[2]] + ); + } +} + +#[test] +fn test_gen_fake_x25519_key() { + let rng = crate::crypto::SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); + + assert_eq!(key1.len(), 32); + assert_eq!(key2.len(), 32); + assert_ne!(key1, key2); +} + +#[test] +fn test_fake_x25519_key_is_nonzero_and_varies() { + let rng = crate::crypto::SecureRandom::new(); + let mut unique = std::collections::HashSet::new(); + let mut saw_non_zero = false; + + for _ in 0..64 { + let key = gen_fake_x25519_key(&rng); + if key != [0u8; 32] { + saw_non_zero = true; + } + unique.insert(key); + } + + assert!( + saw_non_zero, + "generated X25519 public keys must not collapse to all-zero output" + ); + assert!( + unique.len() > 1, + "generated X25519 public keys must vary across invocations" + ); +} + +#[test] +fn validate_tls_handshake_rejects_session_id_longer_than_rfc_cap() { + let secret = b"session_id_cap_secret"; + let oversized_sid = vec![0x42u8; 33]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &oversized_sid); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + assert!( + validate_tls_handshake(&handshake, &secrets, true).is_none(), + "legacy_session_id length > 32 must be rejected" + ); +} + +fn server_hello_extension_types(record: &[u8]) -> Vec { + if record.len() < 9 || record[0] != TLS_RECORD_HANDSHAKE || record[5] != 0x02 { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([record[3], record[4]]) as usize; + if record.len() < 5 + record_len { + return Vec::new(); + } + + let hs_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let hs_start = 5; + let hs_end = hs_start + 4 + hs_len; + if hs_end > record.len() { + return Vec::new(); + } + + let mut pos = hs_start + 4 + 2 + 32; + if pos >= hs_end { + return Vec::new(); + } + let sid_len = record[pos] as 
usize; + pos += 1 + sid_len; + if pos + 2 + 1 + 2 > hs_end { + return Vec::new(); + } + + pos += 2 + 1; + let ext_len = u16::from_be_bytes([record[pos], record[pos + 1]]) as usize; + pos += 2; + let ext_end = pos + ext_len; + if ext_end > hs_end { + return Vec::new(); + } + + let mut out = Vec::new(); + while pos + 4 <= ext_end { + let etype = u16::from_be_bytes([record[pos], record[pos + 1]]); + let elen = u16::from_be_bytes([record[pos + 2], record[pos + 3]]) as usize; + pos += 4; + if pos + elen > ext_end { + break; + } + out.push(etype); + pos += elen; + } + out +} + +#[test] +fn build_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_sh_forbidden"; + let client_digest = [0x11u8; 32]; + let session_id = vec![0xAA; 32]; + let rng = crate::crypto::SecureRandom::new(); + + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 1024, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in ServerHello" + ); +} + +#[test] +fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() { + let secret = b"alpn_emulated_forbidden"; + let client_digest = [0x22u8; 32]; + let session_id = vec![0xAB; 32]; + let rng = crate::crypto::SecureRandom::new(); + let cached = CachedTlsData { + server_hello_template: ParsedServerHello { + version: TLS_VERSION, + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload: None, + app_data_records_sizes: vec![1024], + total_app_data_len: 1024, + behavior_profile: TlsBehaviorProfile { + change_cipher_spec_count: 1, + app_data_record_sizes: vec![1024], + ticket_record_sizes: Vec::new(), + source: TlsProfileSource::Default, + }, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + }; + + let response = build_emulated_server_hello( + secret, + 
&client_digest, + &session_id, + &cached, + false, + &rng, + Some(b"h2".to_vec()), + 0, + ); + let exts = server_hello_extension_types(&response); + assert!( + !exts.contains(&0x0010), + "ALPN extension must not appear in emulated ServerHello" + ); +} + +#[test] +fn test_tls_extension_builder() { + let key = [0x42u8; 32]; + + let mut builder = TlsExtensionBuilder::new(); + builder.add_key_share(&key); + builder.add_supported_versions(0x0304); + + let result = builder.build(); + let len = u16::from_be_bytes([result[0], result[1]]) as usize; + + assert_eq!(len, result.len() - 2); + assert!(result.len() > 40); +} + +#[test] +fn test_server_hello_builder() { + let session_id = vec![0x01, 0x02, 0x03, 0x04]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id.clone()) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); + + assert_eq!(record[0], TLS_RECORD_HANDSHAKE); + assert_eq!(&record[1..3], &TLS_VERSION); + assert_eq!(record[5], 0x02); +} + +#[test] +fn test_build_server_hello_structure() { + let secret = b"test secret"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); + + assert!(response.len() > 100); + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + validate_server_hello_structure(&response).expect("Invalid ServerHello"); + + let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = server_hello_len; + assert!(response.len() > ccs_start + 6); + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + + let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; + let app_start = ccs_start + ccs_len; + assert!(response.len() > app_start + 5); + 
assert_eq!(response[app_start], TLS_RECORD_APPLICATION); +} + +#[test] +fn test_build_server_hello_digest() { + let secret = b"test secret key here"; + let client_digest = [0x42u8; 32]; + let session_id = vec![0xAA; 32]; + + let rng = crate::crypto::SecureRandom::new(); + let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); + + let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert!(!digest1.iter().all(|&b| b == 0)); + + let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; + assert_ne!(digest1, digest2); +} + +#[test] +fn test_server_hello_extensions_length() { + let session_id = vec![0x01; 32]; + let key = [0x55u8; 32]; + + let builder = ServerHelloBuilder::new(session_id) + .with_x25519_key(&key) + .with_tls13_version(); + + let record = builder.build_record(); + let msg_start = 5; + let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; + let session_id_pos = msg_start + 4 + 2 + 32; + let session_id_len = record[session_id_pos] as usize; + let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; + let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; + let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; + + assert_eq!( + ext_len, + extensions_data.len(), + "Extension length mismatch: declared {}, actual {}", + ext_len, + extensions_data.len() + ); +} + +#[test] +fn test_validate_tls_handshake_format() { + let mut handshake = vec![0u8; 100]; + handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].copy_from_slice(&[0x42; 32]); + handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; + + let secrets = vec![("test".to_string(), b"secret".to_vec())]; + let result = validate_tls_handshake(&handshake, &secrets, true); + assert!(result.is_none()); +} + +fn build_client_hello_with_exts(exts: 
Vec<(u16, Vec)>, host: &str) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let host_bytes = host.as_bytes(); + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_ext.extend_from_slice(host_bytes); + + let mut ext_blob = Vec::new(); + for (typ, data) in exts { + ext_blob.extend_from_slice(&typ.to_be_bytes()); + ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&data); + } + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +fn build_client_hello_with_raw_extensions(ext_blob: &[u8]) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(0); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let len_bytes = (body.len() as 
u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record +} + +#[test] +fn test_extract_sni_with_grease_extension() { + let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("example.com")); +} + +#[test] +fn test_extract_sni_tolerates_empty_unknown_extension() { + let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("test.local")); +} + +#[test] +fn test_extract_alpn_single() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&3u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2"]); +} + +#[test] +fn test_extract_alpn_multiple() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&11u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + alpn_data.push(4); + alpn_data.extend_from_slice(b"spdy"); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h3"); + let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let alpn = extract_alpn_from_client_hello(&ch); + let alpn_str: Vec = alpn + .iter() + .map(|p| std::str::from_utf8(p).unwrap().to_string()) + .collect(); + assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); +} + +#[test] +fn extract_sni_rejects_zero_length_host_name() { + let mut sni_ext = 
Vec::new(); + sni_ext.extend_from_slice(&3u16.to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&0u16.to_be_bytes()); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_raw_ipv4_literals() { + let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_invalid_label_characters() { + let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_oversized_label() { + let oversized = format!("{}.example.com", "a".repeat(64)); + let ch = build_client_hello_with_exts(Vec::new(), &oversized); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 0]); + + let mut ch = build_client_hello_with_raw_extensions(&ext_blob); + ch.pop(); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_session_id_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + ch[sid_len_pos] = 255; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_cipher_suites_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize; + ch[cipher_len_pos] = 0xFF; + 
ch[cipher_len_pos + 1] = 0xFF; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_compression_methods_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize; + let cipher_len = u16::from_be_bytes([ch[cipher_len_pos], ch[cipher_len_pos + 1]]) as usize; + let comp_len_pos = cipher_len_pos + 2 + cipher_len; + ch[comp_len_pos] = 0xFF; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_alpn_returns_empty_on_session_id_len_overflow() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&3u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + let mut ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let sid_len_pos = 5 + 4 + 2 + 32; + ch[sid_len_pos] = 255; + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +#[test] +fn extract_alpn_rejects_when_extension_block_is_truncated() { + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&5u16.to_be_bytes()); + ext_blob.extend_from_slice(&[0, 3, 2, b'h']); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +#[test] +fn extract_alpn_rejects_nested_length_overflow() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&10u16.to_be_bytes()); + alpn_data.push(8); + alpn_data.extend_from_slice(b"h2"); + + let mut ext_blob = Vec::new(); + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + + let ch = build_client_hello_with_raw_extensions(&ext_blob); + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + +// ------------------------------------------------------------------ 
// Additional adversarial checks
// ------------------------------------------------------------------

#[test]
fn empty_secret_hmac_is_supported() {
    let secret: &[u8] = b"";
    let handshake = make_valid_tls_handshake(secret, 0);
    let secrets = vec![("empty".to_string(), secret.to_vec())];
    let result = validate_tls_handshake(&handshake, &secrets, true);
    assert!(result.is_some(), "Empty HMAC key must not panic and must validate when correct");
}

/// A client recomputes HMAC(secret, client_digest || response-with-zeroed-digest)
/// and compares it to the embedded digest; the server must produce exactly that.
#[test]
fn server_hello_digest_verifies_against_full_response() {
    let secret = b"fronting_digest_verify_key";
    let client_digest = [0x42u8; TLS_DIGEST_LEN];
    let session_id = vec![0xAA; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 1);
    let mut zeroed = response.clone();
    zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);

    let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len());
    hmac_input.extend_from_slice(&client_digest);
    hmac_input.extend_from_slice(&zeroed);
    let expected = sha256_hmac(secret, &hmac_input);

    assert_eq!(
        &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN],
        &expected,
        "ServerHello digest must be verifiable by a client that recomputes HMAC over full response"
    );
}

/// Flipping a single byte anywhere in the response must break digest
/// verification — the HMAC covers the full response bytes.
#[test]
fn server_hello_digest_fails_after_single_byte_tamper() {
    let secret = b"fronting_tamper_detect_key";
    let client_digest = [0x24u8; TLS_DIGEST_LEN];
    let session_id = vec![0xBB; 32];
    let rng = crate::crypto::SecureRandom::new();

    let mut response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
    response[TLS_DIGEST_POS + TLS_DIGEST_LEN + 1] ^= 0x01;

    let mut zeroed = response.clone();
    zeroed[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);

    let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + zeroed.len());
    hmac_input.extend_from_slice(&client_digest);
    hmac_input.extend_from_slice(&zeroed);
    let expected = sha256_hmac(secret, &hmac_input);

    assert_ne!(
        &response[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN],
        &expected,
        "Tampering any response byte must invalidate the embedded digest"
    );
}

/// The decoy ApplicationData payload must look like ciphertext: nonzero and
/// different across responses, otherwise it is a stable DPI fingerprint.
#[test]
fn server_hello_application_data_payload_varies_across_runs() {
    use std::collections::HashSet;

    let secret = b"fronting_payload_variability_key";
    let client_digest = [0x13u8; TLS_DIGEST_LEN];
    let session_id = vec![0x44; 32];
    let rng = crate::crypto::SecureRandom::new();

    // NOTE(review): generic parameter restored — extracted text read `HashSet>`.
    let mut unique_payloads: HashSet<Vec<u8>> = HashSet::new();
    for _ in 0..16 {
        let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);

        let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
        let ccs_pos = 5 + sh_len;
        let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
        let app_pos = ccs_pos + 5 + ccs_len;

        assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
        let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
        let payload = response[app_pos + 5..app_pos + 5 + app_len].to_vec();

        assert!(payload.iter().any(|&b| b != 0), "Payload must not be all-zero deterministic filler");
        unique_payloads.insert(payload);
    }

    assert!(
        unique_payloads.len() >= 4,
        "ApplicationData payload should vary across runs to reduce fingerprintability"
    );
}

#[test]
fn replay_window_zero_disables_boot_bypass_for_any_nonzero_timestamp() {
    let secret = b"window_zero_boot_bypass_test";
    let secrets = vec![("u".to_string(), secret.to_vec())];

    let ts1 = make_valid_tls_handshake(secret, 1);
    assert!(
        validate_tls_handshake_with_replay_window(&ts1, &secrets, false, 0).is_none(),
        "replay_window_secs=0 must reject nonzero timestamps even in boot-time range"
    );

    let ts0 = make_valid_tls_handshake(secret, 0);
    assert!(
        validate_tls_handshake_with_replay_window(&ts0, &secrets, false, 0).is_none(),
        "replay_window_secs=0 enforces strict skew check and rejects timestamp=0 on normal wall-clock systems"
    );
}

#[test]
fn large_replay_window_does_not_expand_time_skew_acceptance() {
    let secret = b"large_replay_window_skew_bound_test";
    let secrets = vec![("u".to_string(), secret.to_vec())];
    let now: i64 = 1_700_000_000;

    let ts_far_past = (now - 600) as u32;
    let valid = make_valid_tls_handshake(secret, ts_far_past);
    assert!(
        validate_tls_handshake_with_replay_window(&valid, &secrets, false, 86_400).is_none(),
        "large replay window must not relax strict skew check once boot-time bypass is not in play"
    );
}

#[test]
fn parse_tls_record_header_accepts_tls_version_constant() {
    let header = [TLS_RECORD_HANDSHAKE, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x2A];
    let parsed = parse_tls_record_header(&header).expect("TLS_VERSION header should be accepted");
    assert_eq!(parsed.0, TLS_RECORD_HANDSHAKE);
    assert_eq!(parsed.1, 42);
}

#[test]
fn server_hello_clamps_fake_cert_len_lower_bound() {
    let secret = b"fake_cert_lower_bound_test";
    let client_digest = [0x11u8; TLS_DIGEST_LEN];
    let session_id = vec![0x77; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 1, &rng, None, 0);

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;

    assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
    assert_eq!(app_len, 64, "fake cert payload must be clamped to minimum 64 bytes");
}

#[test]
fn server_hello_clamps_fake_cert_len_upper_bound() {
    let secret = b"fake_cert_upper_bound_test";
    let client_digest = [0x22u8; TLS_DIGEST_LEN];
    let session_id = vec![0x66; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 65_535, &rng, None, 0);

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;

    assert_eq!(response[app_pos], TLS_RECORD_APPLICATION);
    assert_eq!(app_len, MAX_TLS_CIPHERTEXT_SIZE, "fake cert payload must be clamped to TLS record max bound");
}

#[test]
fn server_hello_new_session_ticket_count_matches_configuration() {
    let secret = b"ticket_count_surface_test";
    let client_digest = [0x33u8; TLS_DIGEST_LEN];
    let session_id = vec![0x55; 32];
    let rng = crate::crypto::SecureRandom::new();

    let tickets: u8 = 3;
    let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, tickets);

    // Walk every TLS record and count ApplicationData records.
    let mut pos = 0usize;
    let mut app_records = 0usize;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(next <= response.len(), "TLS record must stay inside response bounds");
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
        }
        pos = next;
    }

    assert_eq!(
        app_records,
        1 + tickets as usize,
        "response must contain one main application record plus configured ticket-like tail records"
    );
}

#[test]
fn server_hello_new_session_ticket_count_is_safely_capped() {
    let secret = b"ticket_count_cap_test";
    let client_digest = [0x44u8; TLS_DIGEST_LEN];
    let session_id = vec![0x54; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, u8::MAX);

    let mut pos = 0usize;
    let mut app_records = 0usize;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(next <= response.len(), "TLS record must stay inside response bounds");
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
        }
        pos = next;
    }

    assert_eq!(
        app_records,
        5,
        "response must cap ticket-like tail records to four plus one main application record"
    );
}

#[test]
fn boot_time_handshake_replay_remains_blocked_after_cache_window_expires() {
    let secret = b"gap_t01_boot_replay";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake must validate on first use");

    let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
    let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

    assert!(
        !checker.check_and_add_tls_digest(digest_half),
        "first use must not be treated as replay"
    );
    assert!(
        checker.check_and_add_tls_digest(digest_half),
        "immediate second use must be detected as replay"
    );

    std::thread::sleep(std::time::Duration::from_millis(70));

    let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake must still cryptographically validate after cache expiry");
    let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];
    assert_eq!(digest_half, digest_half_after_expiry, "replay key must be stable for same handshake");

    assert!(
        checker.check_and_add_tls_digest(digest_half_after_expiry),
        "after cache window expiry, the same boot-time handshake must still be treated as replay"
    );
}

#[test]
fn adversarial_boot_time_handshake_should_not_be_replayable_after_cache_expiry() {
    let secret = b"gap_t01_boot_replay_adversarial";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake must validate on first use");

    let checker = crate::stats::ReplayChecker::new(128, std::time::Duration::from_millis(40));
    let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

    assert!(
        !checker.check_and_add_tls_digest(digest_half),
        "first use must not be treated as replay"
    );
    assert!(
        checker.check_and_add_tls_digest(digest_half),
        "immediate reuse must be rejected as replay"
    );

    std::thread::sleep(std::time::Duration::from_millis(70));

    let validation_after_expiry = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
        .expect("boot-time handshake still validates cryptographically after cache expiry");
    let digest_half_after_expiry = &validation_after_expiry.digest[..TLS_DIGEST_HALF_LEN];

    assert_eq!(
        digest_half, digest_half_after_expiry,
        "replay key must remain stable for the same captured handshake"
    );

    assert!(
        checker.check_and_add_tls_digest(digest_half_after_expiry),
        "security expectation: a boot-time handshake should remain replay-protected even after cache expiry"
    );
}

#[test]
fn stress_short_replay_window_boot_timestamp_replay_cycles_remain_fail_closed_in_window() {
    let secret = b"gap_t01_boot_replay_stress";
    let secrets = vec![("user".to_string(), secret.to_vec())];
    let handshake = make_valid_tls_handshake(secret, 1);

    let checker = crate::stats::ReplayChecker::new(256, std::time::Duration::from_millis(25));

    for cycle in 0..64 {
        let validation = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
            .expect("boot-time handshake must validate");
        let digest_half = &validation.digest[..TLS_DIGEST_HALF_LEN];

        if cycle == 0 {
            assert!(
                !checker.check_and_add_tls_digest(digest_half),
                "cycle 0: first use must be fresh"
            );
            assert!(
                checker.check_and_add_tls_digest(digest_half),
                "cycle 0: second use must be replay"
            );
        } else {
            assert!(
                checker.check_and_add_tls_digest(digest_half),
                "cycle {cycle}: digest must remain replay-protected across short-window churn"
            );
        }

        std::thread::sleep(std::time::Duration::from_millis(30));
    }
}

#[test]
fn light_fuzz_boot_time_timestamp_matrix_with_short_replay_window_obeys_boot_cap() {
    let secret = b"gap_t01_boot_replay_fuzz";
    let secrets = vec![("user".to_string(), secret.to_vec())];

    // Deterministic xorshift-style PRNG; no external crates in tests.
    let mut s: u64 = 0xA1B2_C3D4_55AA_7733;
    for _ in 0..2048 {
        s ^= s << 7;
        s ^= s >> 9;
        s ^= s << 8;
        let ts = (s as u32) % 8;

        let handshake = make_valid_tls_handshake(secret, ts);
        let accepted = validate_tls_handshake_with_replay_window(&handshake, &secrets, false, 2)
            .is_some();

        if ts < 2 {
            assert!(accepted, "timestamp {ts} must remain boot-time compatible under 2s cap");
        } else {
            assert!(
                !accepted,
                "timestamp {ts} must be rejected when outside replay-window boot cap"
            );
        }
    }
}

#[test]
fn server_hello_application_data_contains_alpn_marker_when_selected() {
    let secret = b"alpn_marker_test";
    let client_digest = [0x55u8; TLS_DIGEST_LEN];
    let session_id = vec![0xAB; 32];
    let rng = crate::crypto::SecureRandom::new();

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        512,
        &rng,
        Some(b"h2".to_vec()),
        0,
    );

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
    let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];

    // ALPN extension wire form: type 0x0010, ext_len 5, list_len 3, proto_len 2, "h2".
    let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
    assert!(
        app_payload.windows(expected.len()).any(|window| window == expected),
        "first application payload must carry ALPN marker for selected protocol"
    );
}

#[test]
fn server_hello_ignores_oversized_alpn_and_still_caps_ticket_tail() {
    let secret = b"alpn_oversize_ignore_test";
    let client_digest = [0x56u8; TLS_DIGEST_LEN];
    let session_id = vec![0xCD; 32];
    let rng = crate::crypto::SecureRandom::new();
    let oversized_alpn = vec![b'x'; u8::MAX as usize + 1];

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        512,
        &rng,
        Some(oversized_alpn),
        u8::MAX,
    );

    let mut pos = 0usize;
    let mut app_records = 0usize;
    let mut first_app_payload: Option<&[u8]> = None;
    while pos + 5 <= response.len() {
        let rtype = response[pos];
        let rlen = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
        let next = pos + 5 + rlen;
        assert!(next <= response.len(), "TLS record must stay inside response bounds");
        if rtype == TLS_RECORD_APPLICATION {
            app_records += 1;
            if first_app_payload.is_none() {
                first_app_payload = Some(&response[pos + 5..next]);
            }
        }
        pos = next;
    }
    let marker = [0x00u8, 0x10, 0x00, 0x06, 0x00, 0x04, 0x03, b'x', b'x', b'x', b'x'];

    assert_eq!(
        app_records, 5,
        "oversized ALPN must not change the four-ticket cap on tail records"
    );
    assert!(
        !first_app_payload
            .expect("response must contain an application record")
            .windows(marker.len())
            .any(|window| window == marker),
        "oversized ALPN must be ignored rather than embedded into the first application payload"
    );
}

#[test]
fn server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
    let secret = b"alpn_too_large_to_fit_test";
    let client_digest = [0x57u8; TLS_DIGEST_LEN];
    let session_id = vec![0xEF; 32];
    let rng = crate::crypto::SecureRandom::new();
    let oversized_alpn = vec![0xAB; u8::MAX as usize];

    let response = build_server_hello(
        secret,
        &client_digest,
        &session_id,
        64,
        &rng,
        Some(oversized_alpn),
        0,
    );

    let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize;
    let ccs_pos = 5 + sh_len;
    let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize;
    let app_pos = ccs_pos + 5 + ccs_len;
    let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize;
    let app_payload = &response[app_pos + 5..app_pos + 5 + app_len];

    let mut marker_prefix = Vec::new();
    marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes());
    marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes());
    marker_prefix.push(0xff);
    marker_prefix.extend_from_slice(&[0xab; 8]);
    assert!(
        !app_payload.starts_with(&marker_prefix),
        "oversized ALPN must not be partially embedded into the ServerHello application record"
    );
}

// NOTE(review): the next test in the original chunk,
// `server_hello_embeds_full_alpn_marker_when_it_exactly_fits_fake_cert_len`,
// continues past the end of this view; its remainder must be preserved from
// the following chunk when these edits are merged.
expected_marker.extend_from_slice(&proto); + + assert_eq!(app_payload.len(), expected_marker.len()); + assert_eq!(app_payload, expected_marker.as_slice()); +} + +#[test] +fn server_hello_does_not_embed_partial_alpn_marker_when_one_byte_short() { + let secret = b"alpn_one_byte_short_test"; + let client_digest = [0x59u8; TLS_DIGEST_LEN]; + let session_id = vec![0xA6; 32]; + let rng = crate::crypto::SecureRandom::new(); + let proto = vec![0xAB; 58]; + + // marker_len = 65, fake_cert_len = 64 => marker must be fully skipped. + let response = build_server_hello( + secret, + &client_digest, + &session_id, + 64, + &rng, + Some(proto), + 0, + ); + + let sh_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_pos = 5 + sh_len; + let ccs_len = u16::from_be_bytes([response[ccs_pos + 3], response[ccs_pos + 4]]) as usize; + let app_pos = ccs_pos + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_pos + 3], response[app_pos + 4]]) as usize; + let app_payload = &response[app_pos + 5..app_pos + 5 + app_len]; + + let mut marker_prefix = Vec::new(); + marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Du16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x003Bu16.to_be_bytes()); + marker_prefix.push(58u8); + marker_prefix.extend_from_slice(&[0xAB; 8]); + + assert!( + !app_payload.starts_with(&marker_prefix), + "one-byte-short ALPN marker must be skipped entirely, not partially embedded" + ); +} + +#[test] +fn exhaustive_tls_minor_version_classification_matches_policy() { + for minor in 0u8..=u8::MAX { + let first = [TLS_RECORD_HANDSHAKE, 0x03, minor]; + let expected = minor == 0x01 || minor == 0x03; + assert_eq!( + is_tls_handshake(&first), + expected, + "minor version {minor:#04x} classification mismatch" + ); + } +} + +#[test] +fn light_fuzz_tls_header_classifier_and_parser_policy_consistency() { + // Deterministic xorshift state keeps this fuzz test reproducible. 
+ let mut s: u64 = 0x9E37_79B9_AA95_5A5D; + + for _ in 0..10_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let header = [ + (s & 0xff) as u8, + ((s >> 8) & 0xff) as u8, + ((s >> 16) & 0xff) as u8, + ((s >> 24) & 0xff) as u8, + ((s >> 32) & 0xff) as u8, + ]; + + let classified = is_tls_handshake(&header[..3]); + let expected_classified = header[0] == TLS_RECORD_HANDSHAKE + && header[1] == 0x03 + && (header[2] == 0x01 || header[2] == 0x03); + assert_eq!( + classified, + expected_classified, + "classifier policy mismatch for header {header:02x?}" + ); + + let parsed = parse_tls_record_header(&header); + let expected_parsed = header[1] == 0x03 && (header[2] == 0x01 || header[2] == TLS_VERSION[1]); + assert_eq!( + parsed.is_some(), + expected_parsed, + "parser policy mismatch for header {header:02x?}" + ); + } +} + +#[test] +fn stress_random_noise_handshakes_never_authenticate() { + let secret = b"stress_noise_secret"; + let secrets = vec![("noise-user".to_string(), secret.to_vec())]; + + // Deterministic xorshift state keeps this stress test reproducible. 
+ let mut s: u64 = 0xD1B5_4A32_9C6E_77F1; + + for _ in 0..5_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let len = 1 + ((s as usize) % 196); + let mut buf = vec![0u8; len]; + for b in &mut buf { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + *b = (s & 0xff) as u8; + } + + assert!( + validate_tls_handshake(&buf, &secrets, true).is_none(), + "random noise must never authenticate" + ); + } +} diff --git a/src/protocol/tests/tls_size_constants_security_tests.rs b/src/protocol/tests/tls_size_constants_security_tests.rs new file mode 100644 index 0000000..1389ab6 --- /dev/null +++ b/src/protocol/tests/tls_size_constants_security_tests.rs @@ -0,0 +1,15 @@ +use super::{ + MAX_TLS_CIPHERTEXT_SIZE, + MAX_TLS_PLAINTEXT_SIZE, + MIN_TLS_CLIENT_HELLO_SIZE, +}; + +#[test] +fn tls_size_constants_match_rfc_8446() { + assert_eq!(MAX_TLS_PLAINTEXT_SIZE, 16_384); + assert_eq!(MAX_TLS_CIPHERTEXT_SIZE, 16_640); + + assert!(MIN_TLS_CLIENT_HELLO_SIZE < 512); + assert!(MIN_TLS_CLIENT_HELLO_SIZE > 64); + assert!(MAX_TLS_CIPHERTEXT_SIZE > MAX_TLS_PLAINTEXT_SIZE); +} diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index fbe7ad5..9cac85e 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -11,8 +11,8 @@ use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::ProxyError; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; -use num_bigint::BigUint; -use num_traits::One; +use subtle::ConstantTimeEq; +use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; // ============= Public Constants ============= @@ -26,8 +26,17 @@ pub const TLS_DIGEST_POS: usize = 11; pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) -pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before -pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +/// +/// The default window is intentionally narrow to reduce replay acceptance. 
+/// Operators with known clock-drifted clients should tune deployment config +/// (for example replay-window policy) to match their environment. +pub const TIME_SKEW_MIN: i64 = -2 * 60; // 2 minutes before +pub const TIME_SKEW_MAX: i64 = 2 * 60; // 2 minutes after +/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. +pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; +/// Hard cap for boot-time compatibility bypass to avoid oversized acceptance +/// windows when replay TTL is configured very large. +pub const BOOT_TIME_COMPAT_MAX_SECS: u32 = 2 * 60; // ============= Private Constants ============= @@ -60,6 +69,7 @@ pub struct TlsValidation { /// Client digest for response generation pub digest: [u8; TLS_DIGEST_LEN], /// Timestamp extracted from digest + pub timestamp: u32, } @@ -114,28 +124,8 @@ impl TlsExtensionBuilder { self } - /// Add ALPN extension with a single selected protocol. - fn add_alpn(&mut self, proto: &[u8]) -> &mut Self { - // Extension type: ALPN (0x0010) - self.extensions.extend_from_slice(&extension_type::ALPN.to_be_bytes()); - - // ALPN extension format: - // extension_data length (2 bytes) - // protocols length (2 bytes) - // protocol name length (1 byte) - // protocol name bytes - let proto_len = proto.len() as u8; - let list_len: u16 = 1 + proto_len as u16; - let ext_len: u16 = 2 + list_len; - - self.extensions.extend_from_slice(&ext_len.to_be_bytes()); - self.extensions.extend_from_slice(&list_len.to_be_bytes()); - self.extensions.push(proto_len); - self.extensions.extend_from_slice(proto); - self - } - /// Build final extensions with length prefix + fn build(self) -> Vec { let mut result = Vec::with_capacity(2 + self.extensions.len()); @@ -150,7 +140,7 @@ impl TlsExtensionBuilder { } /// Get current extensions without length prefix (for calculation) - #[allow(dead_code)] + fn as_bytes(&self) -> &[u8] { &self.extensions } @@ -170,8 +160,6 @@ struct ServerHelloBuilder { compression: u8, /// Extensions 
extensions: TlsExtensionBuilder, - /// Selected ALPN protocol (if any) - alpn: Option>, } impl ServerHelloBuilder { @@ -182,7 +170,6 @@ impl ServerHelloBuilder { cipher_suite: cipher_suite::TLS_AES_128_GCM_SHA256, compression: 0x00, extensions: TlsExtensionBuilder::new(), - alpn: None, } } @@ -197,18 +184,9 @@ impl ServerHelloBuilder { self } - fn with_alpn(mut self, proto: Option>) -> Self { - self.alpn = proto; - self - } - /// Build ServerHello message (without record header) fn build_message(&self) -> Vec { - let mut ext_builder = self.extensions.clone(); - if let Some(ref alpn) = self.alpn { - ext_builder.add_alpn(alpn); - } - let extensions = ext_builder.extensions.clone(); + let extensions = self.extensions.extensions.clone(); let extensions_len = extensions.len() as u16; // Calculate total length @@ -273,13 +251,97 @@ impl ServerHelloBuilder { // ============= Public Functions ============= -/// Validate TLS ClientHello against user secrets +/// Validate TLS ClientHello against user secrets. /// /// Returns validation result if a matching user is found. +/// The result **must** be used — ignoring it silently bypasses authentication. +#[must_use] + pub fn validate_tls_handshake( handshake: &[u8], secrets: &[(String, Vec)], ignore_time_skew: bool, +) -> Option { + validate_tls_handshake_with_replay_window( + handshake, + secrets, + ignore_time_skew, + u64::from(BOOT_TIME_MAX_SECS), + ) +} + +/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL. +/// +/// A boot-time timestamp is only accepted when it falls below all three +/// bounds: `BOOT_TIME_MAX_SECS`, configured replay window, and +/// `BOOT_TIME_COMPAT_MAX_SECS`, preventing oversized compatibility windows. +#[must_use] +pub fn validate_tls_handshake_with_replay_window( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + replay_window_secs: u64, +) -> Option { + // Only pay the clock syscall when we will actually compare against it. 
+ // If `ignore_time_skew` is set, a broken or unavailable system clock + // must not block legitimate clients — that would be a DoS via clock failure. + let now = if !ignore_time_skew { + system_time_to_unix_secs(SystemTime::now())? + } else { + 0_i64 + }; + + let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX); + // Boot-time bypass and ignore_time_skew serve different compatibility paths. + // When skew checks are disabled, force boot-time cap to zero to prevent + // accidental future coupling of boot-time logic into the ignore-skew path. + let boot_time_cap_secs = if ignore_time_skew { + 0 + } else { + BOOT_TIME_MAX_SECS + .min(replay_window_u32) + .min(BOOT_TIME_COMPAT_MAX_SECS) + }; + + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + boot_time_cap_secs, + ) +} + +fn system_time_to_unix_secs(now: SystemTime) -> Option { + // `try_from` rejects values that overflow i64 (> ~292 billion years CE), + // whereas `as i64` would silently wrap to a negative timestamp and corrupt + // every subsequent time-skew comparison. + let d = now.duration_since(UNIX_EPOCH).ok()?; + i64::try_from(d.as_secs()).ok() +} + + +fn validate_tls_handshake_at_time( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, +) -> Option { + validate_tls_handshake_at_time_with_boot_cap( + handshake, + secrets, + ignore_time_skew, + now, + BOOT_TIME_MAX_SECS, + ) +} + +fn validate_tls_handshake_at_time_with_boot_cap( + handshake: &[u8], + secrets: &[(String, Vec)], + ignore_time_skew: bool, + now: i64, + boot_time_cap_secs: u32, ) -> Option { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { return None; @@ -293,6 +355,9 @@ pub fn validate_tls_handshake( // Extract session ID let session_id_len_pos = TLS_DIGEST_POS + TLS_DIGEST_LEN; let session_id_len = handshake.get(session_id_len_pos).copied()? 
as usize; + if session_id_len > 32 { + return None; + } let session_id_start = session_id_len_pos + 1; if handshake.len() < session_id_start + session_id_len { @@ -305,73 +370,66 @@ pub fn validate_tls_handshake( let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - // Get current time - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - + let mut first_match: Option<(&String, u32)> = None; + for (user, secret) in secrets { let computed = sha256_hmac(secret, &msg); - - // XOR digests - let xored: Vec = digest.iter() - .zip(computed.iter()) - .map(|(a, b)| a ^ b) - .collect(); - - // Check that first 28 bytes are zeros (timestamp in last 4) - if !xored[..28].iter().all(|&b| b == 0) { + + // Constant-time equality check on the 28-byte HMAC window. + // A variable-time short-circuit here lets an active censor measure how many + // bytes matched, enabling secret brute-force via timing side-channels. + // Direct comparison on the original arrays avoids a heap allocation and + // removes the `try_into().unwrap()` that the intermediate Vec would require. + if !bool::from(digest[..28].ct_eq(&computed[..28])) { continue; } - - // Extract timestamp - let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap()); - let time_diff = now - timestamp as i64; - - // Check time skew + + // The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]). + // Inline array construction is infallible: both slices are [u8; 32] by construction. + let timestamp = u32::from_le_bytes([ + digest[28] ^ computed[28], + digest[29] ^ computed[29], + digest[30] ^ computed[30], + digest[31] ^ computed[31], + ]); + + // time_diff is only meaningful (and `now` is only valid) when we are + // actually checking the window. Keep both inside the guard to make + // the dead-code path explicit and prevent accidental future use of + // a sentinel `now` value outside its intended scope. 
if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds - - if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { - continue; + let is_boot_time = boot_time_cap_secs > 0 && timestamp < boot_time_cap_secs; + if !is_boot_time { + let time_diff = now - i64::from(timestamp); + if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { + continue; + } } } - return Some(TlsValidation { - user: user.clone(), - session_id, - digest, - timestamp, - }); + if first_match.is_none() { + first_match = Some((user, timestamp)); + } } - - None -} -fn curve25519_prime() -> BigUint { - (BigUint::one() << 255) - BigUint::from(19u32) + first_match.map(|(user, timestamp)| TlsValidation { + user: user.clone(), + session_id, + digest, + timestamp, + }) } /// Generate a fake X25519 public key for TLS /// -/// Produces a quadratic residue mod p = 2^255 - 19 by computing n² mod p, -/// which matches Python/C behavior and avoids DPI fingerprinting. +/// Uses RFC 7748 X25519 scalar multiplication over the canonical basepoint, +/// yielding distribution-consistent public keys for anti-fingerprinting. 
pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { - let mut n_bytes = [0u8; 32]; - n_bytes.copy_from_slice(&rng.bytes(32)); - - let n = BigUint::from_bytes_le(&n_bytes); - let p = curve25519_prime(); - let pk = (&n * &n) % &p; - - let mut out = pk.to_bytes_le(); - out.resize(32, 0); - let mut result = [0u8; 32]; - result.copy_from_slice(&out[..32]); - result + let mut scalar = [0u8; 32]; + scalar.copy_from_slice(&rng.bytes(32)); + x25519(scalar, X25519_BASEPOINT_BYTES) } /// Build TLS ServerHello response @@ -392,7 +450,7 @@ pub fn build_server_hello( new_session_tickets: u8, ) -> Vec { const MIN_APP_DATA: usize = 64; - const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 upper bound + const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; let fake_cert_len = fake_cert_len.clamp(MIN_APP_DATA, MAX_APP_DATA); let x25519_key = gen_fake_x25519_key(rng); @@ -400,7 +458,6 @@ pub fn build_server_hello( let server_hello = ServerHelloBuilder::new(session_id.to_vec()) .with_x25519_key(&x25519_key) .with_tls13_version() - .with_alpn(alpn) .build_record(); // Build Change Cipher Spec record @@ -411,8 +468,27 @@ pub fn build_server_hello( 0x01, // CCS byte ]; - // Build fake certificate (Application Data record) - let fake_cert = rng.bytes(fake_cert_len); + // Build first encrypted flight mimic as opaque ApplicationData bytes. + // Embed a compact EncryptedExtensions-like ALPN block when selected. 
+ let mut fake_cert = Vec::with_capacity(fake_cert_len); + if let Some(proto) = alpn.as_ref().filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) { + let proto_list_len = 1usize + proto.len(); + let ext_data_len = 2usize + proto_list_len; + let marker_len = 4usize + ext_data_len; + if marker_len <= fake_cert_len { + fake_cert.extend_from_slice(&0x0010u16.to_be_bytes()); + fake_cert.extend_from_slice(&(ext_data_len as u16).to_be_bytes()); + fake_cert.extend_from_slice(&(proto_list_len as u16).to_be_bytes()); + fake_cert.push(proto.len() as u8); + fake_cert.extend_from_slice(proto); + } + } + if fake_cert.len() < fake_cert_len { + fake_cert.extend_from_slice(&rng.bytes(fake_cert_len - fake_cert.len())); + } else if fake_cert.len() > fake_cert_len { + fake_cert.truncate(fake_cert_len); + } + let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.extend_from_slice(&TLS_VERSION); @@ -424,8 +500,9 @@ pub fn build_server_hello( // Build optional NewSessionTicket records (TLS 1.3 handshake messages are encrypted; // here we mimic with opaque ApplicationData records of plausible size). let mut tickets = Vec::new(); - if new_session_tickets > 0 { - for _ in 0..new_session_tickets { + let ticket_count = new_session_tickets.min(4); + if ticket_count > 0 { + for _ in 0..ticket_count { let ticket_len: usize = rng.range(48) + 48; // 48-95 bytes let mut record = Vec::with_capacity(5 + ticket_len); record.push(TLS_RECORD_APPLICATION); @@ -467,6 +544,11 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return None; + } + let mut pos = 5; // after record header if handshake.get(pos).copied()? 
!= 0x01 { return None; // not ClientHello @@ -506,6 +588,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { return None; } + let mut saw_sni_extension = false; + let mut extracted_sni = None; + while pos + 4 <= ext_end { let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]); let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize; @@ -513,6 +598,12 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if pos + elen > ext_end { break; } + if etype == 0x0000 { + if saw_sni_extension { + return None; + } + saw_sni_extension = true; + } if etype == 0x0000 && elen >= 5 { // server_name extension let list_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize; @@ -528,7 +619,10 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if name_type == 0 && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { - return Some(host.to_string()); + if is_valid_sni_hostname(host) { + extracted_sni = Some(host.to_string()); + break; + } } sn_pos += name_len; } @@ -536,11 +630,49 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { pos += elen; } - None + extracted_sni +} + +fn is_valid_sni_hostname(host: &str) -> bool { + if host.is_empty() || host.len() > 253 { + return false; + } + if host.starts_with('.') || host.ends_with('.') { + return false; + } + if host.parse::().is_ok() { + return false; + } + + for label in host.split('.') { + if label.is_empty() || label.len() > 63 { + return false; + } + if label.starts_with('-') || label.ends_with('-') { + return false; + } + if !label + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + return false; + } + } + + true } /// Extract ALPN protocol list from ClientHello, return in offered order. 
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { + if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE { + return Vec::new(); + } + + let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize; + if handshake.len() < 5 + record_len { + return Vec::new(); + } + let mut pos = 5; // after record header if handshake.get(pos) != Some(&0x01) { return Vec::new(); @@ -592,13 +724,14 @@ pub fn is_tls_handshake(first_bytes: &[u8]) -> bool { return false; } - // TLS record header: 0x16 (handshake) 0x03 0x01 (TLS 1.0) + // TLS ClientHello commonly uses legacy record versions 0x0301 or 0x0303. first_bytes[0] == TLS_RECORD_HANDSHAKE && first_bytes[1] == 0x03 - && first_bytes[2] == 0x01 + && (first_bytes[2] == 0x01 || first_bytes[2] == 0x03) } /// Parse TLS record header, returns (record_type, length) + pub fn parse_tls_record_header(header: &[u8; 5]) -> Option<(u8, u16)> { let record_type = header[0]; let version = [header[1], header[2]]; @@ -667,291 +800,37 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> { Ok(()) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_tls_handshake() { - assert!(is_tls_handshake(&[0x16, 0x03, 0x01])); - assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00])); - assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data - assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version - assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short - } - - #[test] - fn test_parse_tls_record_header() { - let header = [0x16, 0x03, 0x01, 0x02, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_HANDSHAKE); - assert_eq!(result.1, 512); - - let header = [0x17, 0x03, 0x03, 0x40, 0x00]; - let result = parse_tls_record_header(&header).unwrap(); - assert_eq!(result.0, TLS_RECORD_APPLICATION); - assert_eq!(result.1, 16384); - } - - #[test] - fn test_gen_fake_x25519_key() { - let rng = SecureRandom::new(); - let key1 
= gen_fake_x25519_key(&rng); - let key2 = gen_fake_x25519_key(&rng); - - assert_eq!(key1.len(), 32); - assert_eq!(key2.len(), 32); - assert_ne!(key1, key2); // Should be random - } +// ============= Compile-time Security Invariants ============= - #[test] - fn test_fake_x25519_key_is_quadratic_residue() { - let rng = SecureRandom::new(); - let key = gen_fake_x25519_key(&rng); - let p = curve25519_prime(); - let k_num = BigUint::from_bytes_le(&key); - let exponent = (&p - BigUint::one()) >> 1; - let legendre = k_num.modpow(&exponent, &p); - assert_eq!(legendre, BigUint::one()); - } - - #[test] - fn test_tls_extension_builder() { - let key = [0x42u8; 32]; - - let mut builder = TlsExtensionBuilder::new(); - builder.add_key_share(&key); - builder.add_supported_versions(0x0304); - - let result = builder.build(); - - // Check length prefix - let len = u16::from_be_bytes([result[0], result[1]]) as usize; - assert_eq!(len, result.len() - 2); - - // Check key_share extension is present - assert!(result.len() > 40); // At least key share - } - - #[test] - fn test_server_hello_builder() { - let session_id = vec![0x01, 0x02, 0x03, 0x04]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id.clone()) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Validate structure - validate_server_hello_structure(&record).expect("Invalid ServerHello structure"); - - // Check record type - assert_eq!(record[0], TLS_RECORD_HANDSHAKE); - - // Check version - assert_eq!(&record[1..3], &TLS_VERSION); - - // Check message type (ServerHello = 0x02) - assert_eq!(record[5], 0x02); - } - - #[test] - fn test_build_server_hello_structure() { - let secret = b"test secret"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0); - - // Should have at least 3 records - 
assert!(response.len() > 100); - - // First record should be ServerHello - assert_eq!(response[0], TLS_RECORD_HANDSHAKE); - - // Validate ServerHello structure - validate_server_hello_structure(&response).expect("Invalid ServerHello"); - - // Find Change Cipher Spec - let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize; - let ccs_start = server_hello_len; - - assert!(response.len() > ccs_start + 6); - assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); - - // Find Application Data - let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize; - let app_start = ccs_start + ccs_len; - - assert!(response.len() > app_start + 5); - assert_eq!(response[app_start], TLS_RECORD_APPLICATION); - } - - #[test] - fn test_build_server_hello_digest() { - let secret = b"test secret key here"; - let client_digest = [0x42u8; 32]; - let session_id = vec![0xAA; 32]; - - let rng = SecureRandom::new(); - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0); - - // Digest position should have non-zero data - let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert!(!digest1.iter().all(|&b| b == 0)); - - // Different calls should have different digests (due to random cert) - let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; - assert_ne!(digest1, digest2); - } - - #[test] - fn test_server_hello_extensions_length() { - let session_id = vec![0x01; 32]; - let key = [0x55u8; 32]; - - let builder = ServerHelloBuilder::new(session_id) - .with_x25519_key(&key) - .with_tls13_version(); - - let record = builder.build_record(); - - // Parse to find extensions - let msg_start = 5; // After record header - let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize; - - // Skip to session ID - let session_id_pos = 
msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32) - let session_id_len = record[session_id_pos] as usize; - - // Skip to extensions - let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1) - let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize; - - // Verify extensions length matches actual data - let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len]; - assert_eq!(ext_len, extensions_data.len(), - "Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len()); - } - - #[test] - fn test_validate_tls_handshake_format() { - // Build a minimal ClientHello-like structure - let mut handshake = vec![0u8; 100]; - - // Put a valid-looking digest at position 11 - handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN] - .copy_from_slice(&[0x42; 32]); - - // Session ID length - handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32; - - // This won't validate (wrong HMAC) but shouldn't panic - let secrets = vec![("test".to_string(), b"secret".to_vec())]; - let result = validate_tls_handshake(&handshake, &secrets, true); - - // Should return None (no match) but not panic - assert!(result.is_none()); - } +/// Compile-time checks that enforce invariants the rest of the code relies on. +/// Using `static_assertions` ensures these can never silently break across +/// refactors without a compile error. 
+mod compile_time_security_checks { + use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN}; + use static_assertions::const_assert; - fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { - let mut body = Vec::new(); - body.extend_from_slice(&TLS_VERSION); // legacy version - body.extend_from_slice(&[0u8; 32]); // random - body.push(0); // session id len - body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len - body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 - body.push(1); // compression len - body.push(0); // null compression + // The digest must be exactly one SHA-256 output. + const_assert!(TLS_DIGEST_LEN == 32); - // Build SNI extension - let host_bytes = host.as_bytes(); - let mut sni_ext = Vec::new(); - sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); - sni_ext.push(0); - sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); - sni_ext.extend_from_slice(host_bytes); + // Replay-dedup stores the first half; verify it is literally half. 
+ const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN); - let mut ext_blob = Vec::new(); - for (typ, data) in exts { - ext_blob.extend_from_slice(&typ.to_be_bytes()); - ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&data); - } - // SNI last - ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); - ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); - ext_blob.extend_from_slice(&sni_ext); - - body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); - body.extend_from_slice(&ext_blob); - - let mut handshake = Vec::new(); - handshake.push(0x01); // ClientHello - let len_bytes = (body.len() as u32).to_be_bytes(); - handshake.extend_from_slice(&len_bytes[1..4]); - handshake.extend_from_slice(&body); - - let mut record = Vec::new(); - record.push(TLS_RECORD_HANDSHAKE); - record.extend_from_slice(&[0x03, 0x01]); - record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); - record.extend_from_slice(&handshake); - record - } - - #[test] - fn test_extract_sni_with_grease_extension() { - // GREASE type 0x0a0a with zero length before SNI - let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("example.com")); - } - - #[test] - fn test_extract_sni_tolerates_empty_unknown_extension() { - let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); - let sni = extract_sni_from_client_hello(&ch); - assert_eq!(sni.as_deref(), Some("test.local")); - } - - #[test] - fn test_extract_alpn_single() { - let mut alpn_data = Vec::new(); - // list length = 3 (1 length byte + "h2") - alpn_data.extend_from_slice(&3u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| 
std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2"]); - } - - #[test] - fn test_extract_alpn_multiple() { - let mut alpn_data = Vec::new(); - // list length = 11 (sum of per-proto lengths including length bytes) - alpn_data.extend_from_slice(&11u16.to_be_bytes()); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h2"); - alpn_data.push(4); - alpn_data.extend_from_slice(b"spdy"); - alpn_data.push(2); - alpn_data.extend_from_slice(b"h3"); - let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); - let alpn = extract_alpn_from_client_hello(&ch); - let alpn_str: Vec = alpn - .iter() - .map(|p| std::str::from_utf8(p).unwrap().to_string()) - .collect(); - assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]); - } + // The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes) + // must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these + // assertions will catch the mismatch before any timing-oracle fix is broken. 
+ const_assert!(28 + 4 == TLS_DIGEST_LEN); } + +// ============= Security-focused regression tests ============= + +#[cfg(test)] +#[path = "tests/tls_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/tls_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/tls_fuzz_security_tests.rs"] +mod fuzz_security_tests; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 25e6cf9..a68a8c2 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -4,7 +4,11 @@ use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; +use ipnetwork::IpNetwork; +use rand::RngExt; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; use tokio::time::timeout; @@ -17,13 +21,54 @@ type PostHandshakeFuture = Pin> + Send>>; enum HandshakeOutcome { /// Handshake succeeded, relay work to do (outside timeout) NeedsRelay(PostHandshakeFuture), - /// Already fully handled (bad client masking, etc.) 
- Handled, + /// Handshake failed and masking must run outside handshake timeout budget + NeedsMasking(PostHandshakeFuture), +} + +#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"] +struct UserConnectionReservation { + stats: Arc, + ip_tracker: Arc, + user: String, + ip: IpAddr, + active: bool, +} + +impl UserConnectionReservation { + fn new(stats: Arc, ip_tracker: Arc, user: String, ip: IpAddr) -> Self { + Self { + stats, + ip_tracker, + user, + ip, + active: true, + } + } + + async fn release(mut self) { + if !self.active { + return; + } + self.ip_tracker.remove_ip(&self.user, self.ip).await; + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + } +} + +impl Drop for UserConnectionReservation { + fn drop(&mut self) { + if !self.active { + return; + } + self.active = false; + self.stats.decrement_user_curr_connects(&self.user); + self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip); + } } use crate::config::ProxyConfig; use crate::crypto::SecureRandom; -use crate::error::{HandshakeResult, ProxyError, Result}; +use crate::error::{HandshakeResult, ProxyError, Result, StreamError}; use crate::ip_tracker::UserIpTracker; use crate::protocol::constants::*; use crate::protocol::tls; @@ -40,10 +85,118 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::proxy::session_eviction::register_session; fn beobachten_ttl(config: &ProxyConfig) -> Duration { - Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) + const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; + let minutes = config.general.beobachten_minutes; + if minutes == 0 { + static BEOBACHTEN_ZERO_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = 
BEOBACHTEN_ZERO_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "general.beobachten_minutes=0 is insecure because entries expire immediately; forcing minimum TTL to 1 minute" + ); + } + return Duration::from_secs(60); + } + + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + static BEOBACHTEN_OVERSIZED_MINUTES_WARNED: OnceLock = OnceLock::new(); + let warned = BEOBACHTEN_OVERSIZED_MINUTES_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + configured_minutes = minutes, + max_minutes = BEOBACHTEN_TTL_MAX_MINUTES, + "general.beobachten_minutes is too large; clamping to secure maximum" + ); + } + } + + Duration::from_secs(minutes.min(BEOBACHTEN_TTL_MAX_MINUTES).saturating_mul(60)) +} + +fn wrap_tls_application_record(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn tls_clienthello_len_in_bounds(tls_len: usize) -> bool { + (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) +} + +async fn read_with_progress(reader: &mut R, mut buf: &mut [u8]) -> std::io::Result { + let mut total = 0usize; + while !buf.is_empty() { + match reader.read(buf).await { + Ok(0) => return Ok(total), + Ok(n) => { + total += n; + let (_, rest) = buf.split_at_mut(n); + buf = rest; + } + Err(e) => return Err(e), + } + } + Ok(total) +} + +async fn maybe_apply_mask_reject_delay(config: &ProxyConfig) { + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms; + if max == 0 { + return; + } + + let delay_ms = if min >= max { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + +fn 
handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration { + let base = Duration::from_secs(config.timeouts.client_handshake); + if config.censorship.mask { + base.saturating_add(Duration::from_millis(750)) + } else { + base + } +} + +fn masking_outcome( + reader: R, + writer: W, + initial_data: Vec, + peer: SocketAddr, + local_addr: SocketAddr, + config: Arc, + beobachten: Arc, +) -> HandshakeOutcome +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + HandshakeOutcome::NeedsMasking(Box::pin(async move { + handle_bad_client( + reader, + writer, + &initial_data, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + Ok(()) + })) } fn record_beobachten_class( @@ -64,14 +217,34 @@ fn record_handshake_failure_class( peer_ip: IpAddr, error: &ProxyError, ) { - let class = if error.to_string().contains("expected 64 bytes, got 0") { - "expected_64_got_0" - } else { - "other" + let class = match error { + ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + "expected_64_got_0" + } + ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0", + _ => "other", }; record_beobachten_class(beobachten, config, peer_ip, class); } +fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { + if trusted.is_empty() { + static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); + let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); + if !warned.swap(true, Ordering::Relaxed) { + warn!( + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + ); + } + return false; + } + trusted.iter().any(|cidr| cidr.contains(peer_ip)) +} + +fn synthetic_local_addr(port: u16) -> SocketAddr { + SocketAddr::from(([0, 0, 0, 0], port)) +} + pub async fn handle_client_stream( mut stream: S, peer: SocketAddr, @@ -95,9 +268,7 @@ where let mut real_peer = normalize_ip(peer); // For non-TCP streams, use a 
synthetic local address; may be overridden by PROXY protocol dst - let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + let mut local_addr = synthetic_local_addr(config.server.port); if proxy_protocol_enabled { let proxy_header_timeout = Duration::from_millis( @@ -105,6 +276,17 @@ where ); match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { Ok(Ok(info)) => { + if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs) + { + stats.increment_connects_bad(); + warn!( + peer = %peer, + trusted = ?config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class(&beobachten, &config, peer.ip(), "other"); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %peer, client = %info.src_addr, @@ -133,7 +315,7 @@ where debug!(peer = %real_peer, "New connection (generic stream)"); - let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let handshake_timeout = handshake_timeout_with_mask_grace(&config); let stats_for_timeout = stats.clone(); let config_for_timeout = config.clone(); let beobachten_for_timeout = beobachten.clone(); @@ -150,26 +332,68 @@ where if is_tls { let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - if tls_len < 512 { - debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short"); + // RFC 8446 §5.1: TLS record payload MUST NOT exceed 2^14 (16_384) bytes. + // Lower bound is a structural minimum for a valid TLS 1.3 ClientHello + // (record header + handshake header + random + session_id + cipher_suites + // + compression + at least one extension with SNI). 
The previous value of + // 512 was implicitly coupled to TLS_REQUEST_LENGTH=517 from the official + // Telegram MTProxy reference server, leaving only a 5-byte margin and + // incorrectly rejecting compact but spec-compliant ClientHellos from + // third-party clients or future Telegram versions. + if !tls_clienthello_len_in_bounds(tls_len) { + debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds"); stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; let (reader, writer) = tokio::io::split(stream); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); - stream.read_exact(&mut handshake[5..]).await?; + let body_read = match read_with_progress(&mut stream, &mut handshake[5..]).await { + Ok(n) => n, + Err(e) => { + debug!(peer = %real_peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback"); + stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; + let initial_len = 5; + let (reader, writer) = tokio::io::split(stream); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + real_peer, + local_addr, + config.clone(), + beobachten.clone(), + )); + } + }; + + if body_read < tls_len { + debug!(peer = %real_peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback"); + stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&config).await; + let initial_len = 5 + body_read; + let (reader, writer) = tokio::io::split(stream); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + real_peer, + local_addr, + 
config.clone(), + beobachten.clone(), + )); + } let (read_half, write_half) = tokio::io::split(stream); @@ -180,17 +404,15 @@ where HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.clone(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -205,10 +427,32 @@ where &config, &replay_checker, true, Some(tls_user.as_str()), ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { + HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). + let (reader, pending_plaintext) = reader.into_inner_with_pending_plaintext(); + let writer = writer.into_inner(); + let pending_record = if pending_plaintext.is_empty() { + Vec::new() + } else { + wrap_tls_application_record(&pending_plaintext) + }; + let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(HandshakeOutcome::Handled); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); + return Ok(masking_outcome( + reader, + writer, + Vec::new(), + real_peer, + local_addr, + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -225,18 +469,17 @@ where if !config.general.modes.classic && !config.general.modes.secure { debug!(peer = %real_peer, "Non-TLS modes disabled"); stats.increment_connects_bad(); + 
maybe_apply_mask_reject_delay(&config).await; let (reader, writer) = tokio::io::split(stream); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -252,17 +495,15 @@ where HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.to_vec(), real_peer, local_addr, - &config, - &beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -312,8 +553,7 @@ where // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) match outcome { - HandshakeOutcome::NeedsRelay(fut) => fut.await, - HandshakeOutcome::Handled => Ok(()), + HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await, } } @@ -382,7 +622,6 @@ impl RunningClientHandler { pub async fn run(self) -> Result<()> { self.stats.increment_connects_all(); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -393,7 +632,7 @@ impl RunningClientHandler { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } - let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); + let handshake_timeout = handshake_timeout_with_mask_grace(&self.config); let stats = self.stats.clone(); let config_for_timeout = self.config.clone(); let beobachten_for_timeout = self.beobachten.clone(); @@ -427,8 +666,7 @@ impl RunningClientHandler { // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) match outcome { - 
HandshakeOutcome::NeedsRelay(fut) => fut.await, - HandshakeOutcome::Handled => Ok(()), + HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await, } } @@ -446,6 +684,24 @@ impl RunningClientHandler { .await { Ok(Ok(info)) => { + if !is_trusted_proxy_source( + self.peer.ip(), + &self.config.server.proxy_protocol_trusted_cidrs, + ) { + self.stats.increment_connects_bad(); + warn!( + peer = %self.peer, + trusted = ?self.config.server.proxy_protocol_trusted_cidrs, + "Rejecting PROXY protocol header from untrusted source" + ); + record_beobachten_class( + &self.beobachten, + &self.config, + self.peer.ip(), + "other", + ); + return Err(ProxyError::InvalidProxyProtocol); + } debug!( peer = %self.peer, client = %info.src_addr, @@ -495,7 +751,6 @@ impl RunningClientHandler { let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -508,32 +763,72 @@ impl RunningClientHandler { async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + // RFC 8446 §5.1: TLS record payload MUST NOT exceed 2^14 (16_384) bytes. + // Lower bound is a structural minimum for a valid TLS 1.3 ClientHello + // (record header + handshake header + random + session_id + cipher_suites + // + compression + at least one extension with SNI). 
The previous value of + // 512 was implicitly coupled to TLS_REQUEST_LENGTH=517 from the official + // Telegram MTProxy reference server, leaving only a 5-byte margin and + // incorrectly rejecting compact but spec-compliant ClientHellos from + // third-party clients or future Telegram versions. + if !tls_clienthello_len_in_bounds(tls_len) { + debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds"); self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; let (reader, writer) = self.stream.into_split(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), peer, local_addr, - &self.config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + self.config.clone(), + self.beobachten.clone(), + )); } let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; + let body_read = match read_with_progress(&mut self.stream, &mut handshake[5..]).await { + Ok(n) => n, + Err(e) => { + debug!(peer = %peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback"); + self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; + let (reader, writer) = self.stream.into_split(); + return Ok(masking_outcome( + reader, + writer, + handshake[..5].to_vec(), + peer, + local_addr, + self.config.clone(), + self.beobachten.clone(), + )); + } + }; + + if body_read < tls_len { + debug!(peer = %peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback"); + self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; + let initial_len = 5 + body_read; + let (reader, writer) = self.stream.into_split(); + return Ok(masking_outcome( + reader, + writer, + handshake[..initial_len].to_vec(), + peer, + 
local_addr, + self.config.clone(), + self.beobachten.clone(), + )); + } let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); @@ -557,17 +852,15 @@ impl RunningClientHandler { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.clone(), peer, local_addr, - &config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -591,13 +884,32 @@ impl RunningClientHandler { .await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { - reader: _, - writer: _, - } => { + HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). 
+ let (reader, pending_plaintext) = reader.into_inner_with_pending_plaintext(); + let writer = writer.into_inner(); + let pending_record = if pending_plaintext.is_empty() { + Vec::new() + } else { + wrap_tls_application_record(&pending_plaintext) + }; + let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(HandshakeOutcome::Handled); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); + return Ok(masking_outcome( + reader, + writer, + Vec::new(), + peer, + local_addr, + config.clone(), + self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -623,23 +935,21 @@ impl RunningClientHandler { async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; - let _ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); + maybe_apply_mask_reject_delay(&self.config).await; let (reader, writer) = self.stream.into_split(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &first_bytes, + first_bytes.to_vec(), peer, local_addr, - &self.config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + self.config.clone(), + self.beobachten.clone(), + )); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -668,17 +978,15 @@ impl RunningClientHandler { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); - handle_bad_client( + return Ok(masking_outcome( reader, writer, - &handshake, + handshake.to_vec(), peer, local_addr, - &config, - &self.beobachten, - ) - .await; - return Ok(HandshakeOutcome::Handled); + config.clone(), + 
self.beobachten.clone(), + )); } HandshakeResult::Error(e) => return Err(e), }; @@ -727,21 +1035,22 @@ impl RunningClientHandler { { let user = success.user.clone(); - if let Err(e) = Self::check_user_limits_static(&user, &config, &stats, peer_addr, &ip_tracker).await { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); - } - - let registration = register_session(&user, success.dc_idx); - if registration.replaced_existing { - stats.increment_reconnect_evict_total(); - warn!( - user = %user, - dc = success.dc_idx, - "Reconnect detected: replacing active session for user+dc" - ); - } - let session_lease = registration.lease; + let user_limit_reservation = + match Self::acquire_user_connection_reservation_static( + &user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + { + Ok(reservation) => reservation, + Err(e) => { + warn!(user = %user, error = %e, "User admission check failed"); + return Err(e); + } + }; let route_snapshot = route_runtime.snapshot(); let session_id = rng.u64(); @@ -754,7 +1063,7 @@ impl RunningClientHandler { client_writer, success, pool.clone(), - stats, + stats.clone(), config, buffer_pool, local_addr, @@ -762,7 +1071,6 @@ impl RunningClientHandler { route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await } else { @@ -772,14 +1080,13 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await } @@ -790,25 +1097,78 @@ impl RunningClientHandler { client_writer, success, upstream_manager, - stats, + stats.clone(), config, buffer_pool, rng, route_runtime.subscribe(), route_snapshot, session_id, - session_lease.clone(), ) .await }; - - ip_tracker.remove_ip(&user, peer_addr.ip()).await; + user_limit_reservation.release().await; relay_result } + async fn acquire_user_connection_reservation_static( + user: &str, + config: 
&ProxyConfig, + stats: Arc, + peer_addr: SocketAddr, + ip_tracker: Arc, + ) -> Result { + if let Some(expiration) = config.access.user_expirations.get(user) + && chrono::Utc::now() > *expiration + { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); + } + + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config.access.user_max_tcp_conns.get(user).map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => {} + Err(reason) => { + stats.decrement_user_curr_connects(user); + warn!( + user = %user, + ip = %peer_addr.ip(), + reason = %reason, + "IP limit exceeded" + ); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + } + + Ok(UserConnectionReservation::new( + stats, + ip_tracker, + user.to_string(), + peer_addr.ip(), + )) + } + + #[cfg(test)] async fn check_user_limits_static( - user: &str, - config: &ProxyConfig, + user: &str, + config: &ProxyConfig, stats: &Stats, peer_addr: SocketAddr, ip_tracker: &UserIpTracker, @@ -821,9 +1181,32 @@ impl RunningClientHandler { }); } - let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await { - Ok(()) => true, + if let Some(quota) = config.access.user_data_quota.get(user) + && stats.get_user_total_octets(user) >= *quota + { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + + let limit = config + .access + .user_max_tcp_conns + .get(user) + .map(|v| *v as u64); + if !stats.try_acquire_user_curr_connects(user, limit) { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } + + match ip_tracker.check_and_add(user, peer_addr.ip()).await { + Ok(()) => { + 
ip_tracker.remove_ip(user, peer_addr.ip()).await; + stats.decrement_user_curr_connects(user); + } Err(reason) => { + stats.decrement_user_curr_connects(user); warn!( user = %user, ip = %peer_addr.ip(), @@ -834,33 +1217,80 @@ impl RunningClientHandler { user: user.to_string(), }); } - }; - // IP limit check - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) - && stats.get_user_curr_connects(user) >= *limit as u64 - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_tcp_limit_total(); - } - return Err(ProxyError::ConnectionLimitExceeded { - user: user.to_string(), - }); - } - - if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota - { - if ip_reserved { - ip_tracker.remove_ip(user, peer_addr.ip()).await; - stats.increment_ip_reservation_rollback_quota_limit_total(); - } - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); } Ok(()) } } + +#[cfg(test)] +#[path = "tests/client_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/client_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_tls_mtproto_fallback_security_tests.rs"] +mod tls_mtproto_fallback_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_clienthello_size_security_tests.rs"] +mod tls_clienthello_size_security_tests; + +#[cfg(test)] +#[path = "tests/client_tls_clienthello_truncation_adversarial_tests.rs"] +mod tls_clienthello_truncation_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_timing_profile_adversarial_tests.rs"] +mod timing_profile_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_budget_security_tests.rs"] +mod masking_budget_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_redteam_expected_fail_tests.rs"] +mod masking_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_hard_adversarial_tests.rs"] +mod 
masking_hard_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_stress_adversarial_tests.rs"] +mod masking_stress_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_blackhat_campaign_tests.rs"] +mod masking_blackhat_campaign_tests; + +#[cfg(test)] +#[path = "tests/client_masking_diagnostics_security_tests.rs"] +mod masking_diagnostics_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_security_tests.rs"] +mod masking_shape_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_adversarial_tests.rs"] +mod masking_shape_hardening_adversarial_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs"] +mod masking_shape_hardening_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs"] +mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] +mod masking_probe_evasion_blackhat_tests; + +#[cfg(test)] +#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"] +mod beobachten_ttl_bounds_security_tests; diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 108949c..18cbda3 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -1,7 +1,11 @@ +use std::ffi::OsString; use std::fs::OpenOptions; use std::io::Write; use std::net::SocketAddr; +use std::path::{Component, Path, PathBuf}; use std::sync::Arc; +use std::collections::HashSet; +use std::sync::{Mutex, OnceLock}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split}; use tokio::sync::watch; @@ -17,12 +21,218 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::proxy::adaptive_buffers; -use crate::proxy::session_eviction::SessionLease; use 
crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, FromRawFd}; + +const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024; +static LOGGED_UNKNOWN_DCS: OnceLock>> = OnceLock::new(); +const MAX_SCOPE_HINT_LEN: usize = 64; + +fn validated_scope_hint(user: &str) -> Option<&str> { + let scope = user.strip_prefix("scope_")?; + if scope.is_empty() || scope.len() > MAX_SCOPE_HINT_LEN { + return None; + } + if scope + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + Some(scope) + } else { + None + } +} + +#[derive(Clone)] +struct SanitizedUnknownDcLogPath { + resolved_path: PathBuf, + allowed_parent: PathBuf, + file_name: OsString, +} + +// In tests, this function shares global mutable state. Callers that also use +// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions +// deterministic under parallel execution. +fn should_log_unknown_dc(dc_idx: i16) -> bool { + let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new())); + should_log_unknown_dc_with_set(set, dc_idx) +} + +fn should_log_unknown_dc_with_set(set: &Mutex>, dc_idx: i16) -> bool { + match set.lock() { + Ok(mut guard) => { + if guard.contains(&dc_idx) { + return false; + } + if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + return false; + } + guard.insert(dc_idx) + } + // Fail closed on poisoned state to avoid unbounded blocking log writes. 
+ Err(_) => false, + } +} + +fn sanitize_unknown_dc_log_path(path: &str) -> Option { + let candidate = Path::new(path); + if candidate.as_os_str().is_empty() { + return None; + } + if candidate + .components() + .any(|component| matches!(component, Component::ParentDir)) + { + return None; + } + + let cwd = std::env::current_dir().ok()?; + let file_name = candidate.file_name()?; + let parent = candidate.parent().unwrap_or_else(|| Path::new(".")); + let parent_path = if parent.is_absolute() { + parent.to_path_buf() + } else { + cwd.join(parent) + }; + let canonical_parent = parent_path.canonicalize().ok()?; + if !canonical_parent.is_dir() { + return None; + } + + Some(SanitizedUnknownDcLogPath { + resolved_path: canonical_parent.join(file_name), + allowed_parent: canonical_parent, + file_name: file_name.to_os_string(), + }) +} + +fn unknown_dc_log_path_is_still_safe(path: &SanitizedUnknownDcLogPath) -> bool { + let Some(parent) = path.resolved_path.parent() else { + return false; + }; + let Ok(current_parent) = parent.canonicalize() else { + return false; + }; + if current_parent != path.allowed_parent { + return false; + } + + if let Ok(canonical_target) = path.resolved_path.canonicalize() { + let Some(target_parent) = canonical_target.parent() else { + return false; + }; + let Some(target_name) = canonical_target.file_name() else { + return false; + }; + if target_parent != path.allowed_parent || target_name != path.file_name { + return false; + } + } + + true +} + +#[cfg(test)] +fn open_unknown_dc_log_append(path: &Path) -> std::io::Result { + #[cfg(unix)] + { + OpenOptions::new() + .create(true) + .append(true) + .custom_flags(libc::O_NOFOLLOW) + .open(path) + } + #[cfg(not(unix))] + { + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) + } +} + +fn open_unknown_dc_log_append_anchored(path: &SanitizedUnknownDcLogPath) -> std::io::Result { + #[cfg(unix)] + { + 
let parent = OpenOptions::new() + .read(true) + .custom_flags(libc::O_DIRECTORY | libc::O_NOFOLLOW | libc::O_CLOEXEC) + .open(&path.allowed_parent)?; + + let file_name = std::ffi::CString::new(path.file_name.as_os_str().as_bytes()) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "unknown DC log file name contains NUL byte"))?; + + let fd = unsafe { + libc::openat( + parent.as_raw_fd(), + file_name.as_ptr(), + libc::O_CREAT | libc::O_APPEND | libc::O_WRONLY | libc::O_NOFOLLOW | libc::O_CLOEXEC, + 0o600, + ) + }; + + if fd < 0 { + return Err(std::io::Error::last_os_error()); + } + + let file = unsafe { std::fs::File::from_raw_fd(fd) }; + Ok(file) + } + #[cfg(not(unix))] + { + let _ = path; + Err(std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "unknown_dc_file_log_enabled requires unix O_NOFOLLOW support", + )) + } +} + +fn append_unknown_dc_line(file: &mut std::fs::File, dc_idx: i16) -> std::io::Result<()> { + #[cfg(unix)] + { + if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_EX) } != 0 { + return Err(std::io::Error::last_os_error()); + } + + let write_result = writeln!(file, "dc_idx={dc_idx}"); + + if unsafe { libc::flock(file.as_raw_fd(), libc::LOCK_UN) } != 0 { + return Err(std::io::Error::last_os_error()); + } + + write_result + } + #[cfg(not(unix))] + { + writeln!(file, "dc_idx={dc_idx}") + } +} + +#[cfg(test)] +fn clear_unknown_dc_log_cache_for_testing() { + if let Some(set) = LOGGED_UNKNOWN_DCS.get() + && let Ok(mut guard) = set.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn unknown_dc_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + pub(crate) async fn handle_via_direct( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -35,7 +245,6 @@ pub(crate) async fn handle_via_direct( mut route_rx: watch::Receiver, route_snapshot: RouteCutoverState, session_id: u64, - session_lease: SessionLease, ) -> Result<()> where R: 
AsyncRead + Unpin + Send + 'static, @@ -54,12 +263,15 @@ where "Connecting to Telegram DC" ); + let scope_hint = validated_scope_hint(user); + if user.starts_with("scope_") && scope_hint.is_none() { + warn!( + user = %user, + "Ignoring invalid scope hint and falling back to default upstream selection" + ); + } let tg_stream = upstream_manager - .connect( - dc_addr, - Some(success.dc_idx), - user.strip_prefix("scope_").filter(|s| !s.is_empty()), - ) + .connect(dc_addr, Some(success.dc_idx), scope_hint) .await?; debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); @@ -70,29 +282,19 @@ where debug!(peer = %success.peer, "TG handshake complete, starting relay"); stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); - stats.increment_current_connections_direct(); - - let seed_tier = adaptive_buffers::seed_tier_for_user(user); - let (c2s_copy_buf, s2c_copy_buf) = adaptive_buffers::direct_copy_buffers_for_tier( - seed_tier, - config.general.direct_relay_copy_buf_c2s_bytes, - config.general.direct_relay_copy_buf_s2c_bytes, - ); + let _direct_connection_lease = stats.acquire_direct_connection_lease(); let relay_result = relay_bidirectional( client_reader, client_writer, tg_reader, tg_writer, - c2s_copy_buf, - s2c_copy_buf, + config.general.direct_relay_copy_buf_c2s_bytes, + config.general.direct_relay_copy_buf_s2c_bytes, user, - success.dc_idx, Arc::clone(&stats), + config.access.user_data_quota.get(user).copied(), buffer_pool, - session_lease, - seed_tier, ); tokio::pin!(relay_result); let relay_result = loop { @@ -122,9 +324,6 @@ where } }; - stats.decrement_current_connections_direct(); - stats.decrement_user_curr_connects(user); - match &relay_result { Ok(()) => debug!(user = %user, "Direct relay completed"), Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), @@ -181,12 +380,19 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { && let Some(path) = 
&config.general.unknown_dc_log_path && let Ok(handle) = tokio::runtime::Handle::try_current() { - let path = path.clone(); - handle.spawn_blocking(move || { - if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) { - let _ = writeln!(file, "dc_idx={dc_idx}"); + if let Some(path) = sanitize_unknown_dc_log_path(path) { + if should_log_unknown_dc(dc_idx) { + handle.spawn_blocking(move || { + if unknown_dc_log_path_is_still_safe(&path) + && let Ok(mut file) = open_unknown_dc_log_append_anchored(&path) + { + let _ = append_unknown_dc_line(&mut file, dc_idx); + } + }); } - }); + } else { + warn!(dc_idx = dc_idx, raw_path = %path, "Rejected unsafe unknown DC log path"); + } } } @@ -194,7 +400,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { default_dc - 1 } else { - 1 + 0 }; info!( @@ -222,8 +428,6 @@ where let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( success.proto_tag, success.dc_idx, - &success.dec_key, - success.dec_iv, &success.enc_key, success.enc_iv, rng, @@ -249,3 +453,19 @@ where CryptoWriter::new(write_half, tg_encryptor, max_pending), )) } + +#[cfg(test)] +#[path = "tests/direct_relay_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_business_logic_tests.rs"] +mod business_logic_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_common_mistakes_tests.rs"] +mod common_mistakes_tests; + +#[cfg(test)] +#[path = "tests/direct_relay_subtle_adversarial_tests.rs"] +mod subtle_adversarial_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 296432f..8be9075 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -3,14 +3,21 @@ #![allow(dead_code)] use std::net::SocketAddr; +use std::collections::HashSet; +use std::collections::hash_map::RandomState; +use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; -use std::time::Duration; +use 
std::sync::{Mutex, OnceLock}; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::time::{Duration, Instant}; +use dashmap::DashMap; +use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::crypto::{sha256, AesCtr, SecureRandom}; -use rand::Rng; +use rand::RngExt; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; @@ -19,6 +26,461 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; +const ACCESS_SECRET_BYTES: usize = 16; +static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); +#[cfg(test)] +const WARNED_SECRET_MAX_ENTRIES: usize = 64; +#[cfg(not(test))] +const WARNED_SECRET_MAX_ENTRIES: usize = 1_024; + +const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60; +#[cfg(test)] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256; +#[cfg(not(test))] +const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536; +const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024; +const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4; +const AUTH_PROBE_SATURATION_GRACE_FAILS: u32 = 2; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25; + +#[cfg(test)] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16; +#[cfg(not(test))] +const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000; + +#[derive(Clone, Copy)] +struct AuthProbeState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +#[derive(Clone, Copy)] +struct AuthProbeSaturationState { + fail_streak: u32, + blocked_until: Instant, + last_seen: Instant, +} + +static AUTH_PROBE_STATE: OnceLock> = OnceLock::new(); +static AUTH_PROBE_SATURATION_STATE: OnceLock>> = OnceLock::new(); +static AUTH_PROBE_EVICTION_HASHER: OnceLock = OnceLock::new(); + +fn auth_probe_state_map() -> &'static DashMap { + 
AUTH_PROBE_STATE.get_or_init(DashMap::new) +} + +fn auth_probe_saturation_state() -> &'static Mutex> { + AUTH_PROBE_SATURATION_STATE.get_or_init(|| Mutex::new(None)) +} + +fn auth_probe_saturation_state_lock( +) -> std::sync::MutexGuard<'static, Option> { + auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { + match peer_ip { + IpAddr::V4(ip) => IpAddr::V4(ip), + IpAddr::V6(ip) => { + let [a, b, c, d, _, _, _, _] = ip.segments(); + IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0)) + } + } +} + +fn auth_probe_backoff(fail_streak: u32) -> Duration { + if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS { + return Duration::ZERO; + } + let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10); + let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX); + let ms = AUTH_PROBE_BACKOFF_BASE_MS + .saturating_mul(multiplier) + .min(AUTH_PROBE_BACKOFF_MAX_MS); + Duration::from_millis(ms) +} + +fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { + let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS); + now.duration_since(state.last_seen) > retention +} + +fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { + let hasher_state = AUTH_PROBE_EVICTION_HASHER.get_or_init(RandomState::new); + let mut hasher = hasher_state.build_hasher(); + peer_ip.hash(&mut hasher); + now.hash(&mut hasher); + hasher.finish() as usize +} + +fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + now < entry.blocked_until +} + +fn auth_probe_saturation_grace_exhausted(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let 
state = auth_probe_state_map(); + let Some(entry) = state.get(&peer_ip) else { + return false; + }; + if auth_probe_state_expired(&entry, now) { + drop(entry); + state.remove(&peer_ip); + return false; + } + + entry.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS +} + +fn auth_probe_should_apply_preauth_throttle(peer_ip: IpAddr, now: Instant) -> bool { + if !auth_probe_is_throttled(peer_ip, now) { + return false; + } + + if !auth_probe_saturation_is_throttled(now) { + return true; + } + + auth_probe_saturation_grace_exhausted(peer_ip, now) +} + +fn auth_probe_saturation_is_throttled(now: Instant) -> bool { + let mut guard = auth_probe_saturation_state_lock(); + + let Some(state) = guard.as_mut() else { + return false; + }; + + if now.duration_since(state.last_seen) > Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) { + *guard = None; + return false; + } + + if now < state.blocked_until { + return true; + } + + false +} + +fn auth_probe_note_saturation(now: Instant) { + let mut guard = auth_probe_saturation_state_lock(); + + match guard.as_mut() { + Some(state) + if now.duration_since(state.last_seen) + <= Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS) => + { + state.fail_streak = state.fail_streak.saturating_add(1); + state.last_seen = now; + state.blocked_until = now + auth_probe_backoff(state.fail_streak); + } + _ => { + let fail_streak = AUTH_PROBE_BACKOFF_START_FAILS; + *guard = Some(AuthProbeSaturationState { + fail_streak, + blocked_until: now + auth_probe_backoff(fail_streak), + last_seen: now, + }); + } + } +} + +fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + auth_probe_record_failure_with_state(state, peer_ip, now); +} + +fn auth_probe_record_failure_with_state( + state: &DashMap, + peer_ip: IpAddr, + now: Instant, +) { + let make_new_state = || AuthProbeState { + fail_streak: 1, + blocked_until: now + 
auth_probe_backoff(1), + last_seen: now, + }; + + let update_existing = |entry: &mut AuthProbeState| { + if auth_probe_state_expired(entry, now) { + *entry = make_new_state(); + } else { + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + } + }; + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + return; + } + Entry::Vacant(_) => {} + } + + if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + let mut rounds = 0usize; + while state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { + rounds += 1; + if rounds > 8 { + auth_probe_note_saturation(now); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + } + + let Some((evict_key, _, _)) = eviction_candidate else { + return; + }; + state.remove(&evict_key); + break; + } + + let mut stale_keys = Vec::new(); + let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; + let state_len = state.len(); + let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); + let start_offset = if state_len == 0 { + 0 + } else { + auth_probe_eviction_offset(peer_ip, now) % state_len + }; + + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + 
_ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => + { + } + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } + + for stale_key in stale_keys { + state.remove(&stale_key); + } + + if state.len() < AUTH_PROBE_TRACK_MAX_ENTRIES { + break; + } + + let Some((evict_key, _, _)) = eviction_candidate else { + auth_probe_note_saturation(now); + return; + }; + state.remove(&evict_key); + auth_probe_note_saturation(now); + } + } + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + } + Entry::Vacant(entry) => { + entry.insert(make_new_state()); + } + } +} + +fn auth_probe_record_success(peer_ip: IpAddr) { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = auth_probe_state_map(); + state.remove(&peer_ip); +} + +#[cfg(test)] +fn clear_auth_probe_state_for_testing() { + if let Some(state) = AUTH_PROBE_STATE.get() { + state.clear(); + } + if AUTH_PROBE_SATURATION_STATE.get().is_some() { + let mut guard = auth_probe_saturation_state_lock(); + *guard = None; + } +} + +#[cfg(test)] +fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option { + let peer_ip = normalize_auth_probe_ip(peer_ip); + let state = AUTH_PROBE_STATE.get()?; + state.get(&peer_ip).map(|entry| entry.fail_streak) +} + +#[cfg(test)] +fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool { + 
auth_probe_is_throttled(peer_ip, Instant::now()) +} + +#[cfg(test)] +fn auth_probe_saturation_is_throttled_for_testing() -> bool { + auth_probe_saturation_is_throttled(Instant::now()) +} + +#[cfg(test)] +fn auth_probe_saturation_is_throttled_at_for_testing(now: Instant) -> bool { + auth_probe_saturation_is_throttled(now) +} + +#[cfg(test)] +fn auth_probe_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn clear_warned_secrets_for_testing() { + if let Some(warned) = INVALID_SECRET_WARNED.get() + && let Ok(mut guard) = warned.lock() + { + guard.clear(); + } +} + +#[cfg(test)] +fn warned_secrets_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { + let key = (name.to_string(), reason.to_string()); + let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); + let should_warn = match warned.lock() { + Ok(mut guard) => { + if !guard.contains(&key) && guard.len() >= WARNED_SECRET_MAX_ENTRIES { + false + } else { + guard.insert(key) + } + } + Err(_) => true, + }; + + if !should_warn { + return; + } + + match got { + Some(actual) => { + warn!( + user = %name, + expected = expected, + got = actual, + "Skipping user: access secret has unexpected length" + ); + } + None => { + warn!( + user = %name, + "Skipping user: access secret is not valid hex" + ); + } + } +} + +fn decode_user_secret(name: &str, secret_hex: &str) -> Option> { + match hex::decode(secret_hex) { + Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes), + Ok(bytes) => { + warn_invalid_secret_once( + name, + "invalid_length", + ACCESS_SECRET_BYTES, + Some(bytes.len()), + ); + None + } + Err(_) => { + warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None); + None + } + } +} + +// Decide whether a 
client-supplied proto tag is allowed given the configured +// proxy modes and the transport that carried the handshake. +// +// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable +// even though they correspond to different transport profiles: `modes.tls` is +// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto +// over TCP (DD). Enforcing this separation prevents an attacker from using a +// TLS-capable client to bypass the operator intent for the direct MTProto mode, +// and vice versa. +fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool { + match proto_tag { + ProtoTag::Secure => { + if is_tls { + config.general.modes.tls + } else { + config.general.modes.secure + } + } + ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, + } +} + fn decode_user_secrets( config: &ProxyConfig, preferred_user: Option<&str>, @@ -27,7 +489,7 @@ fn decode_user_secrets( if let Some(preferred) = preferred_user && let Some(secret_hex) = config.access.users.get(preferred) - && let Ok(bytes) = hex::decode(secret_hex) + && let Some(bytes) = decode_user_secret(preferred, secret_hex) { secrets.push((preferred.to_string(), bytes)); } @@ -36,7 +498,7 @@ fn decode_user_secrets( if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { continue; } - if let Ok(bytes) = hex::decode(secret_hex) { + if let Some(bytes) = decode_user_secret(name, secret_hex) { secrets.push((name.clone(), bytes)); } } @@ -44,11 +506,29 @@ fn decode_user_secrets( secrets } +async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { + if config.censorship.server_hello_delay_max_ms == 0 { + return; + } + + let min = config.censorship.server_hello_delay_min_ms; + let max = config.censorship.server_hello_delay_max_ms.max(min); + let delay_ms = if max == min { + max + } else { + rand::rng().random_range(min..=max) + }; + + if delay_ms > 0 { + 
tokio::time::sleep(Duration::from_millis(delay_ms)).await; + } +} + /// Result of successful handshake /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is /// zeroized on drop. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct HandshakeSuccess { /// Authenticated user name pub user: String, @@ -65,6 +545,7 @@ pub struct HandshakeSuccess { /// Client address pub peer: SocketAddr, /// Whether TLS was used + pub is_tls: bool, } @@ -94,28 +575,33 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle"); + return HandshakeResult::BadClient { reader, writer }; + } + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; - let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; + let client_sni = tls::extract_sni_from_client_hello(handshake); + let secrets = decode_user_secrets(config, client_sni.as_deref()); - if replay_checker.check_and_add_tls_digest(digest_half) { - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - - let secrets = decode_user_secrets(config, None); - - let validation = match tls::validate_tls_handshake( + let validation = match tls::validate_tls_handshake_with_replay_window( handshake, &secrets, config.access.ignore_time_skew, + config.access.replay_window_secs, ) { Some(v) => v, None => { + auth_probe_record_failure(peer.ip(), Instant::now()); + 
maybe_apply_server_hello_delay(config).await; debug!( peer = %peer, ignore_time_skew = config.access.ignore_time_skew, @@ -125,16 +611,29 @@ where } }; + // Replay tracking is applied only after successful authentication to avoid + // letting unauthenticated probes evict valid entries from the replay cache. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, - None => return HandshakeResult::BadClient { reader, writer }, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } }; let cached = if config.censorship.tls_emulation { if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = tls::extract_sni_from_client_hello(handshake) { + let selected_domain = if let Some(sni) = client_sni.as_ref() { if cache.contains_domain(&sni).await { - sni + sni.clone() } else { config.censorship.tls_domain.clone() } @@ -166,6 +665,10 @@ where Some(b"h2".to_vec()) } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return HandshakeResult::BadClient { reader, writer }; } else { None } @@ -196,19 +699,9 @@ where ) }; - // Optional anti-fingerprint delay before sending ServerHello. 
- if config.censorship.server_hello_delay_max_ms > 0 { - let min = config.censorship.server_hello_delay_min_ms; - let max = config.censorship.server_hello_delay_max_ms.max(min); - let delay_ms = if max == min { - max - } else { - rand::rng().random_range(min..=max) - }; - if delay_ms > 0 { - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - } - } + // Apply the same optional delay budget used by reject paths to reduce + // distinguishability between success and fail-closed handshakes. + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -228,6 +721,8 @@ where "TLS handshake successful" ); + auth_probe_record_success(peer.ip()); + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), @@ -250,15 +745,25 @@ where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, { - trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); + let handshake_fingerprint = { + let digest = sha256(&handshake[..8]); + hex::encode(&digest[..4]) + }; + trace!( + peer = %peer, + handshake_fingerprint = %handshake_fingerprint, + "MTProto handshake prefix" + ); - let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - - if replay_checker.check_and_add_handshake(dec_prekey_iv) { - warn!(peer = %peer, "MTProto replay attack detected"); + let throttle_now = Instant::now(); + if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) { + maybe_apply_server_hello_delay(config).await; + debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle"); return HandshakeResult::BadClient { reader, writer }; } + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let decoded_users = decode_user_secrets(config, preferred_user); @@ -268,57 +773,66 @@ where let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let 
dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); - let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] - .try_into() - .unwrap(); + let tag_bytes: [u8; 4] = [ + decrypted[PROTO_TAG_POS], + decrypted[PROTO_TAG_POS + 1], + decrypted[PROTO_TAG_POS + 2], + decrypted[PROTO_TAG_POS + 3], + ]; let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - let mode_ok = match proto_tag { - ProtoTag::Secure => { - if is_tls { - config.general.modes.tls || config.general.modes.secure - } else { - config.general.modes.secure || config.general.modes.tls - } - } - ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, - }; + let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls); if !mode_ok { debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); continue; } - let dc_idx = i16::from_le_bytes( - decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() - ); + let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len())); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); let enc_key = 
sha256(&enc_key_input); - let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(enc_iv_bytes); + let enc_iv = u128::from_be_bytes(enc_iv_arr); let encryptor = AesCtr::new(&enc_key, enc_iv); +// Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. + if replay_checker.check_and_add_handshake(dec_prekey_iv) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, user = %user, "MTProto replay attack detected"); + return HandshakeResult::BadClient { reader, writer }; + } + let success = HandshakeSuccess { user: user.clone(), dc_idx, @@ -340,6 +854,8 @@ where "MTProto handshake successful" ); + auth_probe_record_success(peer.ip()); + let max_pending = config.general.crypto_pending_buffer; return HandshakeResult::Success(( CryptoReader::new(reader, decryptor), @@ -348,6 +864,8 @@ where )); } + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } @@ -356,8 +874,6 @@ where pub fn generate_tg_nonce( proto_tag: ProtoTag, dc_idx: i16, - _client_dec_key: &[u8; 32], - _client_dec_iv: u128, client_enc_key: &[u8; 32], client_enc_iv: u128, rng: &SecureRandom, @@ -365,14 +881,16 @@ pub fn generate_tg_nonce( ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { loop { let bytes = rng.bytes(HANDSHAKE_LEN); - let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); + let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else { + continue; + }; 
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } - let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } - let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); + let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]]; if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); @@ -380,7 +898,7 @@ pub fn generate_tg_nonce( nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); if fast_mode { - let mut key_iv = Vec::with_capacity(KEY_LEN + IV_LEN); + let mut key_iv = Zeroizing::new(Vec::with_capacity(KEY_LEN + IV_LEN)); key_iv.extend_from_slice(client_enc_key); key_iv.extend_from_slice(&client_enc_iv.to_be_bytes()); key_iv.reverse(); // Python/C behavior: reversed enc_key+enc_iv in nonce @@ -388,13 +906,19 @@ pub fn generate_tg_nonce( } let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); - let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_enc_key = [0u8; 32]; + tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut tg_enc_iv_arr = [0u8; IV_LEN]; + tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr); - let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut tg_dec_key = [0u8; 32]; + tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut tg_dec_iv_arr = [0u8; IV_LEN]; + tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let tg_dec_iv = 
u128::from_be_bytes(tg_dec_iv_arr); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); } @@ -403,13 +927,19 @@ pub fn generate_tg_nonce( /// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, AesCtr, AesCtr) { let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + let dec_key_iv = Zeroizing::new(enc_key_iv.iter().rev().copied().collect::>()); - let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + let mut enc_key = [0u8; 32]; + enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut enc_iv_arr = [0u8; IV_LEN]; + enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let enc_iv = u128::from_be_bytes(enc_iv_arr); - let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); - let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + let mut dec_key = [0u8; 32]; + dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let dec_iv = u128::from_be_bytes(dec_iv_arr); let mut encryptor = AesCtr::new(&enc_key, enc_iv); let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 @@ -418,91 +948,45 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, A result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); let decryptor = AesCtr::new(&dec_key, dec_iv); + enc_key.zeroize(); + dec_key.zeroize(); (result, encryptor, decryptor) } /// Encrypt nonce for sending to Telegram (legacy function for compatibility) + pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); encrypted } #[cfg(test)] -mod tests { - use super::*; +#[path = "tests/handshake_security_tests.rs"] +mod 
security_tests; - #[test] - fn test_generate_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +#[cfg(test)] +#[path = "tests/handshake_adversarial_tests.rs"] +mod adversarial_tests; - let rng = SecureRandom::new(); - let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); +#[cfg(test)] +#[path = "tests/handshake_fuzz_security_tests.rs"] +mod fuzz_security_tests; - assert_eq!(nonce.len(), HANDSHAKE_LEN); +#[cfg(test)] +#[path = "tests/handshake_saturation_poison_security_tests.rs"] +mod saturation_poison_security_tests; - let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); - assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); - } +#[cfg(test)] +#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] +mod auth_probe_hardening_adversarial_tests; - #[test] - fn test_encrypt_tg_nonce() { - let client_dec_key = [0x42u8; 32]; - let client_dec_iv = 12345u128; - let client_enc_key = [0x24u8; 32]; - let client_enc_iv = 54321u128; +/// Compile-time guard: HandshakeSuccess holds cryptographic key material and +/// must never be Copy. A Copy impl would allow silent key duplication, +/// undermining the zeroize-on-drop guarantee. 
+mod compile_time_security_checks { + use super::HandshakeSuccess; + use static_assertions::assert_not_impl_all; - let rng = SecureRandom::new(); - let (nonce, _, _, _, _) = - generate_tg_nonce( - ProtoTag::Secure, - 2, - &client_dec_key, - client_dec_iv, - &client_enc_key, - client_enc_iv, - &rng, - false, - ); - - let encrypted = encrypt_tg_nonce(&nonce); - - assert_eq!(encrypted.len(), HANDSHAKE_LEN); - assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); - } - - #[test] - fn test_handshake_success_zeroize_on_drop() { - let success = HandshakeSuccess { - user: "test".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Secure, - dec_key: [0xAA; 32], - dec_iv: 0xBBBBBBBB, - enc_key: [0xCC; 32], - enc_iv: 0xDDDDDDDD, - peer: "127.0.0.1:1234".parse().unwrap(), - is_tls: true, - }; - - assert_eq!(success.dec_key, [0xAA; 32]); - assert_eq!(success.enc_key, [0xCC; 32]); - - drop(success); - // Drop impl zeroizes key material without panic - } + assert_not_impl_all!(HandshakeSuccess: Copy, Clone); } diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 318071b..d647a3a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -3,23 +3,212 @@ use std::str; use std::net::SocketAddr; use std::time::Duration; +use rand::{Rng, RngExt}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tokio::time::timeout; +use tokio::time::{Instant, timeout}; use tracing::debug; use crate::config::ProxyConfig; use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +#[cfg(not(test))] const MASK_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_TIMEOUT: Duration = Duration::from_millis(50); /// Maximum duration for the entire masking relay. 
/// Limits resource consumption from slow-loris attacks and port scanners. +#[cfg(not(test))] const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); +#[cfg(test)] +const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); +#[cfg(not(test))] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +struct CopyOutcome { + total: usize, + ended_by_eof: bool, +} + +async fn copy_with_idle_timeout(reader: &mut R, writer: &mut W) -> CopyOutcome +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut total = 0usize; + let mut ended_by_eof = false; + loop { + let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await; + let n = match read_res { + Ok(Ok(n)) => n, + Ok(Err(_)) | Err(_) => break, + }; + if n == 0 { + ended_by_eof = true; + break; + } + total = total.saturating_add(n); + + let write_res = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.write_all(&buf[..n])).await; + match write_res { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => break, + } + } + CopyOutcome { + total, + ended_by_eof, + } +} + +fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize { + if total == 0 || floor == 0 || cap < floor { + return total; + } + + if total >= cap { + return total; + } + + let mut bucket = floor; + while bucket < total { + match bucket.checked_mul(2) { + Some(next) => bucket = next, + None => return total, + } + if bucket > cap { + return cap; + } + } + bucket +} + +async fn maybe_write_shape_padding( + mask_write: &mut W, + total_sent: usize, + enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) +where + W: AsyncWrite + Unpin, +{ + if !enabled { + return; + } + + let target_total = if total_sent >= cap && above_cap_blur && above_cap_blur_max_bytes > 0 { + let mut rng = rand::rng(); + let extra 
= rng.random_range(0..=above_cap_blur_max_bytes); + total_sent.saturating_add(extra) + } else { + next_mask_shape_bucket(total_sent, floor, cap) + }; + + if target_total <= total_sent { + return; + } + + let mut remaining = target_total - total_sent; + let mut pad_chunk = [0u8; 1024]; + let deadline = Instant::now() + MASK_TIMEOUT; + + while remaining > 0 { + let now = Instant::now(); + if now >= deadline { + return; + } + + let write_len = remaining.min(pad_chunk.len()); + { + let mut rng = rand::rng(); + rng.fill_bytes(&mut pad_chunk[..write_len]); + } + let write_budget = deadline.saturating_duration_since(now); + match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await { + Ok(Ok(())) => {} + Ok(Err(_)) | Err(_) => return, + } + remaining -= write_len; + } + + let now = Instant::now(); + if now >= deadline { + return; + } + let flush_budget = deadline.saturating_duration_since(now); + let _ = timeout(flush_budget, mask_write.flush()).await; +} + +async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool +where + W: AsyncWrite + Unpin, +{ + match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await { + Ok(Ok(())) => true, + Ok(Err(_)) => false, + Err(_) => { + debug!("Timeout writing proxy protocol header to mask backend"); + false + } + } +} + +async fn consume_client_data_with_timeout(reader: R) +where + R: AsyncRead + Unpin, +{ + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() { + debug!("Timed out while consuming client data on masking fallback path"); + } +} + +async fn wait_mask_connect_budget(started: Instant) { + let elapsed = started.elapsed(); + if elapsed < MASK_TIMEOUT { + tokio::time::sleep(MASK_TIMEOUT - elapsed).await; + } +} + +fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { + if config.censorship.mask_timing_normalization_enabled { + let floor = config.censorship.mask_timing_normalization_floor_ms; + let ceiling = 
config.censorship.mask_timing_normalization_ceiling_ms; + if ceiling > floor { + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(floor..=ceiling)); + } + return Duration::from_millis(floor); + } + + MASK_TIMEOUT +} + +async fn wait_mask_connect_budget_if_needed(started: Instant, config: &ProxyConfig) { + if config.censorship.mask_timing_normalization_enabled { + return; + } + + wait_mask_connect_budget(started).await; +} + +async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) { + let target = mask_outcome_target_budget(config); + let elapsed = started.elapsed(); + if elapsed < target { + tokio::time::sleep(target - elapsed).await; + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -49,6 +238,37 @@ fn detect_client_type(data: &[u8]) -> &'static str { "unknown" } +fn build_mask_proxy_header( + version: u8, + peer: SocketAddr, + local_addr: SocketAddr, +) -> Option> { + match version { + 0 => None, + 2 => Some( + ProxyProtocolV2Builder::new() + .with_addrs(peer, local_addr) + .build(), + ), + _ => { + let header = match (peer, local_addr) { + (SocketAddr::V4(src), SocketAddr::V4(dst)) => { + ProxyProtocolV1Builder::new() + .tcp4(src.into(), dst.into()) + .build() + } + (SocketAddr::V6(src), SocketAddr::V6(dst)) => { + ProxyProtocolV1Builder::new() + .tcp6(src.into(), dst.into()) + .build() + } + _ => ProxyProtocolV1Builder::new().build(), + }; + Some(header) + } + } +} + /// Handle a bad client by forwarding to mask host pub async fn handle_bad_client( reader: R, @@ -71,13 +291,15 @@ where if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; return; } // Connect via Unix socket or TCP #[cfg(unix)] if let Some(ref sock_path) = config.censorship.mask_unix_sock { + let outcome_started = Instant::now(); + let connect_started = 
Instant::now(); debug!( client_type = client_type, sock = %sock_path, @@ -89,39 +311,46 @@ where match connect_result { Ok(Ok(stream)) => { let (mask_read, mut mask_write) = stream.into_split(); - let proxy_header: Option> = match config.censorship.mask_proxy_protocol { - 0 => None, - version => { - let header = match version { - 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), - _ => match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => - ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), - (SocketAddr::V6(src), SocketAddr::V6(dst)) => - ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), - _ => - ProxyProtocolV1Builder::new().build(), - }, - }; - Some(header) - } - }; + let proxy_header = + build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { + wait_mask_outcome_budget(outcome_started, config).await; return; } } - if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { + if timeout( + MASK_RELAY_TIMEOUT, + relay_to_mask( + reader, + writer, + mask_read, + mask_write, + initial_data, + config.censorship.mask_shape_hardening, + config.censorship.mask_shape_bucket_floor_bytes, + config.censorship.mask_shape_bucket_cap_bytes, + config.censorship.mask_shape_above_cap_blur, + config.censorship.mask_shape_above_cap_blur_max_bytes, + ), + ) + .await + .is_err() + { debug!("Mask relay timed out (unix socket)"); } + wait_mask_outcome_budget(outcome_started, config).await; } Ok(Err(e)) => { + wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, 
config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } } return; @@ -143,44 +372,53 @@ where let mask_addr = resolve_socket_addr(mask_host, mask_port) .map(|addr| addr.to_string()) .unwrap_or_else(|| format!("{}:{}", mask_host, mask_port)); + let outcome_started = Instant::now(); + let connect_started = Instant::now(); let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { - let proxy_header: Option> = match config.censorship.mask_proxy_protocol { - 0 => None, - version => { - let header = match version { - 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), - _ => match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => - ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), - (SocketAddr::V6(src), SocketAddr::V6(dst)) => - ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), - _ => - ProxyProtocolV1Builder::new().build(), - }, - }; - Some(header) - } - }; + let proxy_header = + build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); let (mask_read, mut mask_write) = stream.into_split(); if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { + wait_mask_outcome_budget(outcome_started, config).await; return; } } - if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { + if timeout( + MASK_RELAY_TIMEOUT, + relay_to_mask( + reader, + writer, + mask_read, + mask_write, + initial_data, + config.censorship.mask_shape_hardening, + config.censorship.mask_shape_bucket_floor_bytes, + config.censorship.mask_shape_bucket_cap_bytes, + config.censorship.mask_shape_above_cap_blur, + 
config.censorship.mask_shape_above_cap_blur_max_bytes, + ), + ) + .await + .is_err() + { debug!("Mask relay timed out"); } + wait_mask_outcome_budget(outcome_started, config).await; } Ok(Err(e)) => { + wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; + wait_mask_outcome_budget(outcome_started, config).await; } } } @@ -192,6 +430,11 @@ async fn relay_to_mask( mut mask_read: MR, mut mask_write: MW, initial_data: &[u8], + shape_hardening_enabled: bool, + shape_bucket_floor_bytes: usize, + shape_bucket_cap_bytes: usize, + shape_above_cap_blur: bool, + shape_above_cap_blur_max_bytes: usize, ) where R: AsyncRead + Unpin + Send + 'static, @@ -203,47 +446,36 @@ where if mask_write.write_all(initial_data).await.is_err() { return; } - - // Relay traffic - let c2m = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match reader.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = mask_write.shutdown().await; - break; - } - Ok(n) => { - if mask_write.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - let m2c = tokio::spawn(async move { - let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - loop { - match mask_read.read(&mut buf).await { - Ok(0) | Err(_) => { - let _ = writer.shutdown().await; - break; - } - Ok(n) => { - if writer.write_all(&buf[..n]).await.is_err() { - break; - } - } - } - } - }); - - // Wait for either to complete - tokio::select! 
{ - _ = c2m => {} - _ = m2c => {} + if mask_write.flush().await.is_err() { + return; } + + let _ = tokio::join!( + async { + let copied = copy_with_idle_timeout(&mut reader, &mut mask_write).await; + let total_sent = initial_data.len().saturating_add(copied.total); + + let should_shape = shape_hardening_enabled + && copied.ended_by_eof + && !initial_data.is_empty(); + + maybe_write_shape_padding( + &mut mask_write, + total_sent, + should_shape, + shape_bucket_floor_bytes, + shape_bucket_cap_bytes, + shape_above_cap_blur, + shape_above_cap_blur_max_bytes, + ) + .await; + let _ = mask_write.shutdown().await; + }, + async { + let _ = copy_with_idle_timeout(&mut mask_read, &mut writer).await; + let _ = writer.shutdown().await; + } + ); } /// Just consume all data from client without responding @@ -255,3 +487,43 @@ async fn consume_client_data(mut reader: R) { } } } + +#[cfg(test)] +#[path = "tests/masking_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/masking_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_hardening_adversarial_tests.rs"] +mod masking_shape_hardening_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_above_cap_blur_security_tests.rs"] +mod masking_shape_above_cap_blur_security_tests; + +#[cfg(test)] +#[path = "tests/masking_timing_normalization_security_tests.rs"] +mod masking_timing_normalization_security_tests; + +#[cfg(test)] +#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"] +mod masking_ab_envelope_blur_integration_security_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_guard_security_tests.rs"] +mod masking_shape_guard_security_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_guard_adversarial_tests.rs"] +mod masking_shape_guard_adversarial_tests; + +#[cfg(test)] +#[path = "tests/masking_shape_classifier_resistance_adversarial_tests.rs"] +mod masking_shape_classifier_resistance_adversarial_tests; + +#[cfg(test)] +#[path = 
"tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"] +mod masking_timing_sidechannel_redteam_expected_fail_tests; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 102b06c..2000977 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1,15 +1,17 @@ -use std::collections::HashMap; -use std::collections::hash_map::DefaultHasher; +use std::collections::hash_map::RandomState; +use std::collections::{BTreeSet, HashMap}; +use std::hash::BuildHasher; use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant}; -use bytes::Bytes; +use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot, watch}; -use tracing::{debug, trace, warn}; +use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex}; +use tokio::time::timeout; +use tracing::{debug, info, trace, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; @@ -20,25 +22,41 @@ use crate::proxy::route_mode::{ RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, cutover_stagger_delay, }; -use crate::proxy::adaptive_buffers::{self, AdaptiveTier}; -use crate::proxy::session_eviction::SessionLease; use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: Bytes, flags: u32 }, + Data { payload: PooledBuffer, flags: u32 }, Close, } const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); +const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536; +const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024; +const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = 
Duration::from_millis(1000); const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; +const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1); +#[cfg(test)] +const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50); +#[cfg(not(test))] +const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5); const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; -static DESYNC_DEDUP: OnceLock>> = OnceLock::new(); +#[cfg(test)] +const QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const QUOTA_USER_LOCKS_MAX: usize = 4_096; +static DESYNC_DEDUP: OnceLock> = OnceLock::new(); +static DESYNC_HASHER: OnceLock = OnceLock::new(); +static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); +static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); +static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); +static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); struct RelayForensicsState { trace_id: u64, @@ -52,6 +70,140 @@ struct RelayForensicsState { desync_all_full: bool, } +#[derive(Default)] +struct RelayIdleCandidateRegistry { + by_conn_id: HashMap, + ordered: BTreeSet<(u64, u64)>, + pressure_event_seq: u64, + pressure_consumed_seq: u64, +} + +#[derive(Clone, Copy)] +struct RelayIdleCandidateMeta { + mark_order_seq: u64, + mark_pressure_seq: u64, +} + +fn relay_idle_candidate_registry() -> &'static Mutex { + RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) +} + +fn mark_relay_idle_candidate(conn_id: u64) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + if guard.by_conn_id.contains_key(&conn_id) { + return false; + } + + let mark_order_seq = RELAY_IDLE_MARK_SEQ + 
.fetch_add(1, Ordering::Relaxed) + .saturating_add(1); + let meta = RelayIdleCandidateMeta { + mark_order_seq, + mark_pressure_seq: guard.pressure_event_seq, + }; + guard.by_conn_id.insert(conn_id, meta); + guard.ordered.insert((meta.mark_order_seq, conn_id)); + true +} + +fn clear_relay_idle_candidate(conn_id: u64) { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } +} + +#[cfg(test)] +fn oldest_relay_idle_candidate() -> Option { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return None; + }; + guard.ordered.iter().next().map(|(_, conn_id)| *conn_id) +} + +fn note_relay_pressure_event() { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return; + }; + guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1); +} + +fn relay_pressure_event_seq() -> u64 { + let Ok(guard) = relay_idle_candidate_registry().lock() else { + return 0; + }; + guard.pressure_event_seq +} + +fn maybe_evict_idle_candidate_on_pressure( + conn_id: u64, + seen_pressure_seq: &mut u64, + stats: &Stats, +) -> bool { + let Ok(mut guard) = relay_idle_candidate_registry().lock() else { + return false; + }; + + let latest_pressure_seq = guard.pressure_event_seq; + if latest_pressure_seq == *seen_pressure_seq { + return false; + } + *seen_pressure_seq = latest_pressure_seq; + + if latest_pressure_seq == guard.pressure_consumed_seq { + return false; + } + + if guard.ordered.is_empty() { + guard.pressure_consumed_seq = latest_pressure_seq; + return false; + } + + let oldest = guard + .ordered + .iter() + .next() + .map(|(_, candidate_conn_id)| *candidate_conn_id); + if oldest != Some(conn_id) { + return false; + } + + let Some(candidate_meta) = guard.by_conn_id.get(&conn_id).copied() else { + return false; + }; + + // Pressure events that happened before candidate soft-mark are stale for this candidate. 
+ if latest_pressure_seq == candidate_meta.mark_pressure_seq { + return false; + } + + if let Some(meta) = guard.by_conn_id.remove(&conn_id) { + guard.ordered.remove(&(meta.mark_order_seq, conn_id)); + } + guard.pressure_consumed_seq = latest_pressure_seq; + stats.increment_relay_pressure_evict_total(); + true +} + +#[cfg(test)] +fn clear_relay_idle_pressure_state_for_testing() { + if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get() + && let Ok(mut guard) = registry.lock() + { + guard.by_conn_id.clear(); + guard.ordered.clear(); + guard.pressure_event_seq = 0; + guard.pressure_consumed_seq = 0; + } + RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed); +} + #[derive(Clone, Copy)] struct MeD2cFlushPolicy { max_frames: usize, @@ -60,9 +212,64 @@ struct MeD2cFlushPolicy { ack_flush_immediate: bool, } +#[derive(Clone, Copy)] +struct RelayClientIdlePolicy { + enabled: bool, + soft_idle: Duration, + hard_idle: Duration, + grace_after_downstream_activity: Duration, + legacy_frame_read_timeout: Duration, +} + +impl RelayClientIdlePolicy { + fn from_config(config: &ProxyConfig) -> Self { + Self { + enabled: config.timeouts.relay_idle_policy_v2_enabled, + soft_idle: Duration::from_secs(config.timeouts.relay_client_idle_soft_secs.max(1)), + hard_idle: Duration::from_secs(config.timeouts.relay_client_idle_hard_secs.max(1)), + grace_after_downstream_activity: Duration::from_secs( + config + .timeouts + .relay_idle_grace_after_downstream_activity_secs, + ), + legacy_frame_read_timeout: Duration::from_secs(config.timeouts.client_handshake.max(1)), + } + } + + #[cfg(test)] + fn disabled(frame_read_timeout: Duration) -> Self { + Self { + enabled: false, + soft_idle: Duration::from_secs(0), + hard_idle: Duration::from_secs(0), + grace_after_downstream_activity: Duration::from_secs(0), + legacy_frame_read_timeout: frame_read_timeout, + } + } +} + +struct RelayClientIdleState { + last_client_frame_at: Instant, + soft_idle_marked: bool, +} + +impl RelayClientIdleState { + fn 
new(now: Instant) -> Self { + Self { + last_client_frame_at: now, + soft_idle_marked: false, + } + } + + fn on_client_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + self.soft_idle_marked = false; + } +} + impl MeD2cFlushPolicy { - fn from_config(config: &ProxyConfig, tier: AdaptiveTier) -> Self { - let base = Self { + fn from_config(config: &ProxyConfig) -> Self { + Self { max_frames: config .general .me_d2c_flush_batch_max_frames @@ -73,24 +280,13 @@ impl MeD2cFlushPolicy { .max(ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN), max_delay: Duration::from_micros(config.general.me_d2c_flush_batch_max_delay_us), ack_flush_immediate: config.general.me_d2c_ack_flush_immediate, - }; - let (max_frames, max_bytes, max_delay) = adaptive_buffers::me_flush_policy_for_tier( - tier, - base.max_frames, - base.max_bytes, - base.max_delay, - ); - Self { - max_frames, - max_bytes, - max_delay, - ack_flush_immediate: base.ack_flush_immediate, } } } fn hash_value(value: &T) -> u64 { - let mut hasher = DefaultHasher::new(); + let state = DESYNC_HASHER.get_or_init(RandomState::new); + let mut hasher = state.build_hasher(); value.hash(&mut hasher); hasher.finish() } @@ -104,26 +300,122 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { return true; } - let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new())); - let mut guard = dedup.lock().expect("desync dedup mutex poisoned"); - guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW); + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; + let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false)); + if saturated_before { + ever_saturated.store(true, Ordering::Relaxed); + } - match guard.get_mut(&key) { - Some(seen_at) => { - if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { - *seen_at = now; + if let Some(mut seen_at) = dedup.get_mut(&key) { + if 
now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { + *seen_at = now; + return true; + } + return false; + } + + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let mut stale_keys = Vec::new(); + let mut oldest_candidate: Option<(u64, Instant)> = None; + for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { + let key = *entry.key(); + let seen_at = *entry.value(); + + match oldest_candidate { + Some((_, oldest_seen)) if seen_at >= oldest_seen => {} + _ => oldest_candidate = Some((key, seen_at)), + } + + if now.duration_since(seen_at) >= DESYNC_DEDUP_WINDOW { + stale_keys.push(*entry.key()); + } + } + for stale_key in stale_keys { + dedup.remove(&stale_key); + } + if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let Some((evict_key, _)) = oldest_candidate else { + return false; + }; + dedup.remove(&evict_key); + dedup.insert(key, now); + return should_emit_full_desync_full_cache(now); + } + } + + dedup.insert(key, now); + let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES; + // Preserve the first sequential insert that reaches capacity as a normal + // emit, while still gating concurrent newcomer churn after the cache has + // ever been observed at saturation. 
+ let was_ever_saturated = if saturated_after { + ever_saturated.swap(true, Ordering::Relaxed) + } else { + ever_saturated.load(Ordering::Relaxed) + }; + + if saturated_before || (saturated_after && was_ever_saturated) { + should_emit_full_desync_full_cache(now) + } else { + true + } +} + +fn should_emit_full_desync_full_cache(now: Instant) -> bool { + let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); + let Ok(mut last_emit_at) = gate.lock() else { + return false; + }; + + match *last_emit_at { + None => { + *last_emit_at = Some(now); + true + } + Some(last) => { + let Some(elapsed) = now.checked_duration_since(last) else { + *last_emit_at = Some(now); + return true; + }; + if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL { + *last_emit_at = Some(now); true } else { false } } - None => { - guard.insert(key, now); - true + } +} + +#[cfg(test)] +fn clear_desync_dedup_for_testing() { + if let Some(dedup) = DESYNC_DEDUP.get() { + dedup.clear(); + } + if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() { + ever_saturated.store(false, Ordering::Relaxed); + } + if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() { + match last_emit_at.lock() { + Ok(mut guard) => { + *guard = None; + } + Err(poisoned) => { + let mut guard = poisoned.into_inner(); + *guard = None; + last_emit_at.clear_poison(); + } } } } +#[cfg(test)] +fn desync_dedup_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + fn report_desync_frame_too_large( state: &RelayForensicsState, proto_tag: ProtoTag, @@ -152,6 +444,7 @@ fn report_desync_frame_too_large( let bytes_me2c = state.bytes_me2c.load(Ordering::Relaxed); stats.increment_desync_total(); + stats.increment_relay_protocol_desync_close_total(); stats.observe_desync_frames_ok(frame_counter); if emit_full { stats.increment_desync_full_logged(); @@ -219,23 +512,60 @@ fn should_yield_c2me_sender(sent_since_yield: usize, 
has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } +fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { + quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) +} + +fn quota_would_be_exceeded_for_user( + stats: &Stats, + user: &str, + quota_limit: Option, + bytes: u64, +) -> bool { + quota_limit.is_some_and(|quota| { + let used = stats.get_user_total_octets(user); + used >= quota || bytes > quota.saturating_sub(used) + }) +} + +fn quota_user_lock(user: &str) -> Arc> { + let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + if let Some(existing) = locks.get(user) { + return Arc::clone(existing.value()); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + return Arc::new(AsyncMutex::new(())); + } + + let created = Arc::new(AsyncMutex::new(())); + match locks.entry(user.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&created)); + created + } + } +} + async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, - send_timeout: Duration, ) -> std::result::Result<(), mpsc::error::SendError> { match tx.try_send(cmd) { Ok(()) => Ok(()), Err(mpsc::error::TrySendError::Closed(cmd)) => Err(mpsc::error::SendError(cmd)), Err(mpsc::error::TrySendError::Full(cmd)) => { + note_relay_pressure_event(); // Cooperative yield reduces burst catch-up when the per-conn queue is near saturation. 
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS { tokio::task::yield_now().await; } - if send_timeout.is_zero() { - return tx.send(cmd).await; - } - match tokio::time::timeout(send_timeout, tx.reserve()).await { + match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await { Ok(Ok(permit)) => { permit.send(cmd); Ok(()) @@ -254,23 +584,22 @@ pub(crate) async fn handle_via_middle_proxy( me_pool: Arc, stats: Arc, config: Arc, - _buffer_pool: Arc, + buffer_pool: Arc, local_addr: SocketAddr, rng: Arc, mut route_rx: watch::Receiver, route_snapshot: RouteCutoverState, session_id: u64, - session_lease: SessionLease, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { let user = success.user.clone(); + let quota_limit = config.access.user_data_quota.get(&user).copied(); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); - let seed_tier = adaptive_buffers::seed_tier_for_user(&user); debug!( user = %user, @@ -283,7 +612,7 @@ where ); let (conn_id, me_rx) = me_pool.registry().register().await; - let trace_id = conn_id; + let trace_id = session_id; let bytes_me2c = Arc::new(AtomicU64::new(0)); let mut forensics = RelayForensicsState { trace_id, @@ -298,8 +627,7 @@ where }; stats.increment_user_connects(&user); - stats.increment_user_curr_connects(&user); - stats.increment_current_connections_me(); + let _me_connection_lease = stats.acquire_me_connection_lease(); if let Some(cutover) = affected_cutover_state( &route_rx, @@ -317,20 +645,9 @@ where tokio::time::sleep(delay).await; let _ = me_pool.send_close(conn_id).await; me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); } - if session_lease.is_stale() { - stats.increment_reconnect_stale_close_total(); - let _ = me_pool.send_close(conn_id).await; - 
me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); - return Err(ProxyError::Proxy("Session evicted by reconnect".to_string())); - } - // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) let user_tag: Option> = config .access @@ -361,12 +678,15 @@ where let translated_local_addr = me_pool.translate_our_addr(local_addr); let frame_limit = config.general.max_client_frame; + let relay_idle_policy = RelayClientIdlePolicy::from_config(&config); + let session_started_at = forensics.started_at; + let mut relay_idle_state = RelayClientIdleState::new(session_started_at); + let last_downstream_activity_ms = Arc::new(AtomicU64::new(0)); let c2me_channel_capacity = config .general .me_c2me_channel_capacity .max(C2ME_CHANNEL_CAPACITY_FALLBACK); - let c2me_send_timeout = Duration::from_millis(config.general.me_c2me_send_timeout_ms); let (c2me_tx, mut c2me_rx) = mpsc::channel::(c2me_channel_capacity); let me_pool_c2me = me_pool.clone(); let effective_tag = effective_tag; @@ -375,42 +695,15 @@ where while let Some(cmd) = c2me_rx.recv().await { match cmd { C2MeCommand::Data { payload, flags } => { - if c2me_send_timeout.is_zero() { - me_pool_c2me - .send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ) - .await?; - } else { - match tokio::time::timeout( - c2me_send_timeout, - me_pool_c2me.send_proxy_req( - conn_id, - success.dc_idx, - peer, - translated_local_addr, - payload.as_ref(), - flags, - effective_tag.as_deref(), - ), - ) - .await - { - Ok(send_result) => send_result?, - Err(_) => { - return Err(ProxyError::Proxy(format!( - "ME send timeout after {}ms", - c2me_send_timeout.as_millis() - ))); - } - } - } + me_pool_c2me.send_proxy_req( + conn_id, + success.dc_idx, + peer, + translated_local_addr, + payload.as_ref(), + flags, + effective_tag.as_deref(), + ).await?; 
sent_since_yield = sent_since_yield.saturating_add(1); if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) { sent_since_yield = 0; @@ -431,8 +724,9 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); + let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); - let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config, seed_tier); + let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); let me_writer = tokio::spawn(async move { let mut writer = crypto_writer; let mut frame_buf = Vec::with_capacity(16 * 1024); @@ -448,6 +742,8 @@ where let mut batch_bytes = 0usize; let mut flush_immediately; + let first_is_downstream_activity = + matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( first, &mut writer, @@ -456,12 +752,17 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, false, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if first_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately = immediate; @@ -480,6 +781,8 @@ where break; }; + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -488,12 +791,17 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -512,6 +820,8 @@ where { match tokio::time::timeout(d2c_flush_policy.max_delay, me_rx_task.recv()).await { Ok(Some(next)) => { + let next_is_downstream_activity = + matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( next, &mut writer, @@ -520,12 +830,17 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? { MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if next_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -544,6 +859,8 @@ where break; }; + let extra_is_downstream_activity = + matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); match process_me_writer_response( extra, &mut writer, @@ -552,12 +869,17 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_limit, bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, true, ).await? 
{ MeWriterResponseOutcome::Continue { frames, bytes, flush_immediately: immediate } => { + if extra_is_downstream_activity { + last_downstream_activity_ms_clone + .store(session_started_at.elapsed().as_millis() as u64, Ordering::Relaxed); + } batch_frames = batch_frames.saturating_add(frames); batch_bytes = batch_bytes.saturating_add(bytes); flush_immediately |= immediate; @@ -591,13 +913,24 @@ where let mut client_closed = false; let mut frame_counter: u64 = 0; let mut route_watch_open = true; + let mut seen_pressure_seq = relay_pressure_event_seq(); loop { - if session_lease.is_stale() { - stats.increment_reconnect_stale_close_total(); - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; - main_result = Err(ProxyError::Proxy("Session evicted by reconnect".to_string())); + if relay_idle_policy.enabled + && maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen_pressure_seq, stats.as_ref()) + { + info!( + conn_id, + trace_id = format_args!("0x{:016x}", trace_id), + user = %user, + "Middle-relay pressure eviction for idle-candidate session" + ); + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + main_result = Err(ProxyError::Proxy( + "middle-relay session evicted under pressure (idle-candidate)".to_string(), + )); break; } + if let Some(cutover) = affected_cutover_state( &route_rx, RelayRouteMode::Middle, @@ -612,7 +945,7 @@ where "Cutover affected middle session, closing client connection" ); tokio::time::sleep(delay).await; - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); break; } @@ -623,13 +956,18 @@ where route_watch_open = false; } } - payload_result = read_client_payload( + payload_result = read_client_payload_with_idle_policy( &mut crypto_reader, proto_tag, frame_limit, + &buffer_pool, &forensics, &mut frame_counter, &stats, + 
&relay_idle_policy, + &mut relay_idle_state, + last_downstream_activity_ms.as_ref(), + session_started_at, ) => { match payload_result { Ok(Some((payload, quickack))) => { @@ -637,7 +975,19 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - stats.add_user_octets_from(&user, payload.len() as u64); + if let Some(limit) = quota_limit { + let quota_lock = quota_user_lock(&user); + let _quota_guard = quota_lock.lock().await; + stats.add_user_octets_from(&user, payload.len() as u64); + if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + main_result = Err(ProxyError::DataQuotaExceeded { + user: user.clone(), + }); + break; + } + } else { + stats.add_user_octets_from(&user, payload.len() as u64); + } let mut flags = proto_flags; if quickack { flags |= RPC_FLAG_QUICKACK; @@ -646,13 +996,9 @@ where flags |= RPC_FLAG_NOT_ENCRYPTED; } // Keep client read loop lightweight: route heavy ME send path via a dedicated task. - if enqueue_c2me_command( - &c2me_tx, - C2MeCommand::Data { payload, flags }, - c2me_send_timeout, - ) - .await - .is_err() + if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags }) + .await + .is_err() { main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); break; @@ -661,12 +1007,7 @@ where Ok(None) => { debug!(conn_id, "Client EOF"); client_closed = true; - let _ = enqueue_c2me_command( - &c2me_tx, - C2MeCommand::Close, - c2me_send_timeout, - ) - .await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; break; } Err(e) => { @@ -715,41 +1056,203 @@ where frames_ok = frame_counter, "ME relay cleanup" ); - adaptive_buffers::record_user_tier(&user, seed_tier); + clear_relay_idle_candidate(conn_id); me_pool.registry().unregister(conn_id).await; - stats.decrement_current_connections_me(); - stats.decrement_user_curr_connects(&user); result } -async fn read_client_payload( +async fn read_client_payload_with_idle_policy( client_reader: &mut CryptoReader, 
proto_tag: ProtoTag, max_frame: usize, + buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, -) -> Result> + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, +) -> Result> where R: AsyncRead + Unpin + Send + 'static, { + async fn read_exact_with_policy( + client_reader: &mut CryptoReader, + buf: &mut [u8], + idle_policy: &RelayClientIdlePolicy, + idle_state: &mut RelayClientIdleState, + last_downstream_activity_ms: &AtomicU64, + session_started_at: Instant, + forensics: &RelayForensicsState, + stats: &Stats, + read_label: &'static str, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + { + fn hard_deadline( + idle_policy: &RelayClientIdlePolicy, + idle_state: &RelayClientIdleState, + session_started_at: Instant, + last_downstream_activity_ms: u64, + ) -> Instant { + let mut deadline = idle_state.last_client_frame_at + idle_policy.hard_idle; + if idle_policy.grace_after_downstream_activity.is_zero() { + return deadline; + } + + let downstream_at = session_started_at + Duration::from_millis(last_downstream_activity_ms); + if downstream_at > idle_state.last_client_frame_at { + let grace_deadline = downstream_at + idle_policy.grace_after_downstream_activity; + if grace_deadline > deadline { + deadline = grace_deadline; + } + } + deadline + } + + let mut filled = 0usize; + while filled < buf.len() { + let timeout_window = if idle_policy.enabled { + let now = Instant::now(); + let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = hard_deadline( + idle_policy, + idle_state, + session_started_at, + downstream_ms, + ); + if now >= hard_deadline { + clear_relay_idle_candidate(forensics.conn_id); + stats.increment_relay_idle_hard_close_total(); + let client_idle_secs = now + .saturating_duration_since(idle_state.last_client_frame_at) + .as_secs(); + let downstream_idle_secs = 
now + .saturating_duration_since(session_started_at + Duration::from_millis(downstream_ms)) + .as_secs(); + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + client_idle_secs, + downstream_idle_secs, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay hard idle close" + ); + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", + idle_policy.soft_idle.as_secs(), + idle_policy.hard_idle.as_secs(), + idle_policy.grace_after_downstream_activity.as_secs(), + ), + ))); + } + + if !idle_state.soft_idle_marked + && now.saturating_duration_since(idle_state.last_client_frame_at) + >= idle_policy.soft_idle + { + idle_state.soft_idle_marked = true; + if mark_relay_idle_candidate(forensics.conn_id) { + stats.increment_relay_idle_soft_mark_total(); + } + info!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay soft idle mark" + ); + } + + let soft_deadline = idle_state.last_client_frame_at + idle_policy.soft_idle; + let next_deadline = if idle_state.soft_idle_marked { + hard_deadline + } else { + soft_deadline.min(hard_deadline) + }; + let mut remaining = next_deadline.saturating_duration_since(now); + if remaining.is_zero() { + remaining = Duration::from_millis(1); + } + remaining.min(RELAY_IDLE_IO_POLL_MAX) + } else { + idle_policy.legacy_frame_read_timeout 
+ }; + + let read_result = timeout(timeout_window, client_reader.read(&mut buf[filled..])).await; + match read_result { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::from( + std::io::ErrorKind::UnexpectedEof, + ))); + } + Ok(Ok(n)) => { + filled = filled.saturating_add(n); + } + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) if !idle_policy.enabled => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("middle-relay client frame read timeout while reading {read_label}"), + ))); + } + Err(_) => {} + } + } + + Ok(()) + } + loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_policy( + client_reader, + &mut first, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.first_len_byte", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; + read_exact_with_policy( + client_reader, + &mut ext, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "abridged.extended_len", + ) + .await?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize } else { (first[0] & 0x7f) as usize @@ -762,10 +1265,24 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => 
return Err(ProxyError::Io(e)), + match read_exact_with_policy( + client_reader, + &mut len_buf, + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "len_prefix", + ) + .await + { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (len_buf[3] & 0x80) != 0; ( @@ -788,6 +1305,7 @@ where proto = ?proto_tag, "Frame too small — corrupt or probe" ); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); } @@ -808,6 +1326,7 @@ where Some(payload_len) => payload_len, None => { stats.increment_secure_padding_invalid(); + stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy(format!( "Invalid secure frame length: {len}" ))); @@ -817,21 +1336,98 @@ where len }; - let mut payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; + let mut payload = buffer_pool.get(); + payload.clear(); + let current_cap = payload.capacity(); + if current_cap < len { + payload.reserve(len - current_cap); + } + payload.resize(len, 0); + read_exact_with_policy( + client_reader, + &mut payload[..len], + idle_policy, + idle_state, + last_downstream_activity_ms, + session_started_at, + forensics, + stats, + "payload", + ) + .await?; // Secure Intermediate: strip validated trailing padding bytes. 
if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((Bytes::from(payload), quickack))); + idle_state.on_client_frame(Instant::now()); + clear_relay_idle_candidate(forensics.conn_id); + return Ok(Some((payload, quickack))); } } +#[cfg(test)] +async fn read_client_payload_legacy( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let now = Instant::now(); + let mut idle_state = RelayClientIdleState::new(now); + let last_downstream_activity_ms = AtomicU64::new(0); + let idle_policy = RelayClientIdlePolicy::disabled(frame_read_timeout); + read_client_payload_with_idle_policy( + client_reader, + proto_tag, + max_frame, + buffer_pool, + forensics, + frame_counter, + stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + now, + ) + .await +} + +#[cfg(test)] +async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, + max_frame: usize, + frame_read_timeout: Duration, + buffer_pool: &Arc, + forensics: &RelayForensicsState, + frame_counter: &mut u64, + stats: &Stats, +) -> Result> +where + R: AsyncRead + Unpin + Send + 'static, +{ + read_client_payload_legacy( + client_reader, + proto_tag, + max_frame, + frame_read_timeout, + buffer_pool, + forensics, + frame_counter, + stats, + ) + .await +} + enum MeWriterResponseOutcome { Continue { frames: usize, @@ -849,6 +1445,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_limit: Option, bytes_me2c: &AtomicU64, conn_id: u64, ack_flush_immediate: bool, @@ -864,17 +1461,47 @@ where } else { trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } - bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); - stats.add_user_octets_to(user, data.len() 
as u64); - write_client_payload( - client_writer, - proto_tag, - flags, - &data, - rng, - frame_buf, - ) - .await?; + let data_len = data.len() as u64; + if let Some(limit) = quota_limit { + let quota_lock = quota_user_lock(user); + let _quota_guard = quota_lock.lock().await; + if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + ) + .await?; + + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); + stats.add_user_octets_to(user, data.len() as u64); + + if quota_exceeded_for_user(stats, user, Some(limit)) { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + } else { + write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + ) + .await?; + + bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed); + stats.add_user_octets_to(user, data.len() as u64); + } Ok(MeWriterResponseOutcome::Continue { frames: 1, @@ -1020,84 +1647,13 @@ where } #[cfg(test)] -mod tests { - use super::*; - use tokio::time::{Duration as TokioDuration, timeout}; +#[path = "tests/middle_relay_security_tests.rs"] +mod security_tests; - #[test] - fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - } +#[cfg(test)] +#[path = "tests/middle_relay_idle_policy_security_tests.rs"] +mod idle_policy_security_tests; - #[tokio::test] - async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), - flags: 0, - }, - 
TokioDuration::from_millis(50), - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } - - #[tokio::test] - async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), - flags: 7, - }, - TokioDuration::from_millis(100), - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } - } -} +#[cfg(test)] +#[path = "tests/middle_relay_desync_all_full_dedup_security_tests.rs"] +mod desync_all_full_dedup_security_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 2b12d5a..88a8bd5 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -53,20 +53,17 @@ use std::io; use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; +use dashmap::DashMap; use tokio::io::{ AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes, }; use tokio::time::Instant; use tracing::{debug, trace, warn}; -use 
crate::error::Result; -use crate::proxy::adaptive_buffers::{ - self, AdaptiveTier, RelaySignalSample, SessionAdaptiveController, TierTransitionReason, -}; -use crate::proxy::session_eviction::SessionLease; +use crate::error::{ProxyError, Result}; use crate::stats::Stats; use crate::stream::BufferPool; @@ -83,7 +80,11 @@ const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800); /// 10 seconds gives responsive timeout detection (±10s accuracy) /// without measurable overhead from atomic reads. const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10); -const ADAPTIVE_TICK: Duration = Duration::from_millis(250); + +#[inline] +fn watchdog_delta(current: u64, previous: u64) -> u64 { + current.saturating_sub(previous) +} // ============= CombinedStream ============= @@ -160,16 +161,6 @@ struct SharedCounters { s2c_ops: AtomicU64, /// Milliseconds since relay epoch of last I/O activity last_activity_ms: AtomicU64, - /// Bytes requested to write to client (S→C direction). - s2c_requested_bytes: AtomicU64, - /// Total write operations for S→C direction. - s2c_write_ops: AtomicU64, - /// Number of partial writes to client. - s2c_partial_writes: AtomicU64, - /// Number of times S→C poll_write returned Pending. - s2c_pending_writes: AtomicU64, - /// Consecutive pending writes in S→C direction. 
- s2c_consecutive_pending_writes: AtomicU64, } impl SharedCounters { @@ -180,11 +171,6 @@ impl SharedCounters { c2s_ops: AtomicU64::new(0), s2c_ops: AtomicU64::new(0), last_activity_ms: AtomicU64::new(0), - s2c_requested_bytes: AtomicU64::new(0), - s2c_write_ops: AtomicU64::new(0), - s2c_partial_writes: AtomicU64::new(0), - s2c_pending_writes: AtomicU64::new(0), - s2c_consecutive_pending_writes: AtomicU64::new(0), } } @@ -225,6 +211,12 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, + quota_limit: Option, + quota_exceeded: Arc, + quota_read_wake_scheduled: bool, + quota_write_wake_scheduled: bool, + quota_read_retry_active: Arc, + quota_write_retry_active: Arc, epoch: Instant, } @@ -234,11 +226,135 @@ impl StatsIo { counters: Arc, stats: Arc, user: String, + quota_limit: Option, + quota_exceeded: Arc, epoch: Instant, ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); - Self { inner, counters, stats, user, epoch } + Self { + inner, + counters, + stats, + user, + quota_limit, + quota_exceeded, + quota_read_wake_scheduled: false, + quota_write_wake_scheduled: false, + quota_read_retry_active: Arc::new(AtomicBool::new(false)), + quota_write_retry_active: Arc::new(AtomicBool::new(false)), + epoch, + } + } +} + +impl Drop for StatsIo { + fn drop(&mut self) { + self.quota_read_retry_active.store(false, Ordering::Relaxed); + self.quota_write_retry_active.store(false, Ordering::Relaxed); + } +} + +#[derive(Debug)] +struct QuotaIoSentinel; + +impl std::fmt::Display for QuotaIoSentinel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("user data quota exceeded") + } +} + +impl std::error::Error for QuotaIoSentinel {} + +fn quota_io_error() -> io::Error { + io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel) +} + +fn is_quota_io_error(err: &io::Error) -> bool { + err.kind() == io::ErrorKind::PermissionDenied + && err + .get_ref() + 
.and_then(|source| source.downcast_ref::()) + .is_some() +} + +#[cfg(test)] +const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); +#[cfg(not(test))] +const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); + +fn spawn_quota_retry_waker(retry_active: Arc, waker: std::task::Waker) { + tokio::task::spawn(async move { + loop { + if !retry_active.load(Ordering::Relaxed) { + break; + } + tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await; + if !retry_active.load(Ordering::Relaxed) { + break; + } + waker.wake_by_ref(); + } + }); +} + +static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); +static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); + +#[cfg(test)] +const QUOTA_USER_LOCKS_MAX: usize = 64; +#[cfg(not(test))] +const QUOTA_USER_LOCKS_MAX: usize = 4_096; +#[cfg(test)] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; +#[cfg(not(test))] +const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; + +#[cfg(test)] +fn quota_user_lock_test_guard() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[cfg(test)] +fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { + quota_user_lock_test_guard() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn quota_overflow_user_lock(user: &str) -> Arc> { + let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { + (0..QUOTA_OVERFLOW_LOCK_STRIPES) + .map(|_| Arc::new(Mutex::new(()))) + .collect() + }); + + let hash = crc32fast::hash(user.as_bytes()) as usize; + Arc::clone(&stripes[hash % stripes.len()]) +} + +fn quota_user_lock(user: &str) -> Arc> { + let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + if let Some(existing) = locks.get(user) { + return Arc::clone(existing.value()); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + locks.retain(|_, value| Arc::strong_count(value) > 1); + } + + if locks.len() >= QUOTA_USER_LOCKS_MAX { + return quota_overflow_user_lock(user); 
+ } + + let created = Arc::new(Mutex::new(())); + match locks.entry(user.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(entry) => { + entry.insert(Arc::clone(&created)); + created + } } } @@ -249,12 +365,68 @@ impl AsyncRead for StatsIo { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); + if this.quota_exceeded.load(Ordering::Relaxed) { + return Poll::Ready(Err(quota_io_error())); + } + + let quota_lock = this + .quota_limit + .is_some() + .then(|| quota_user_lock(&this.user)); + let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + this.quota_read_wake_scheduled = false; + this.quota_read_retry_active.store(false, Ordering::Relaxed); + Some(guard) + } + Err(_) => { + if !this.quota_read_wake_scheduled { + this.quota_read_wake_scheduled = true; + this.quota_read_retry_active.store(true, Ordering::Relaxed); + spawn_quota_retry_waker( + Arc::clone(&this.quota_read_retry_active), + cx.waker().clone(), + ); + } + return Poll::Pending; + } + } + } else { + None + }; + + if let Some(limit) = this.quota_limit + && this.stats.get_user_total_octets(&this.user) >= limit + { + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } let before = buf.filled().len(); match Pin::new(&mut this.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { + let mut reached_quota_boundary = false; + if let Some(limit) = this.quota_limit { + let used = this.stats.get_user_total_octets(&this.user); + if used >= limit { + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + + let remaining = limit - used; + if (n as u64) > remaining { + // Fail closed: when a single read chunk would cross quota, + // stop relay immediately without accounting beyond the cap. 
+ this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + + reached_quota_boundary = (n as u64) == remaining; + } + // C→S: client sent data this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); @@ -263,6 +435,10 @@ impl AsyncRead for StatsIo { this.stats.add_user_octets_from(&this.user, n as u64); this.stats.increment_user_msgs_from(&this.user); + if reached_quota_boundary { + this.quota_exceeded.store(true, Ordering::Relaxed); + } + trace!(user = %this.user, bytes = n, "C->S"); } Poll::Ready(Ok(())) @@ -279,21 +455,58 @@ impl AsyncWrite for StatsIo { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - this.counters - .s2c_requested_bytes - .fetch_add(buf.len() as u64, Ordering::Relaxed); + if this.quota_exceeded.load(Ordering::Relaxed) { + return Poll::Ready(Err(quota_io_error())); + } - match Pin::new(&mut this.inner).poll_write(cx, buf) { - Poll::Ready(Ok(n)) => { - this.counters.s2c_write_ops.fetch_add(1, Ordering::Relaxed); - this.counters - .s2c_consecutive_pending_writes - .store(0, Ordering::Relaxed); - if n < buf.len() { - this.counters - .s2c_partial_writes - .fetch_add(1, Ordering::Relaxed); + let quota_lock = this + .quota_limit + .is_some() + .then(|| quota_user_lock(&this.user)); + let _quota_guard = if let Some(lock) = quota_lock.as_ref() { + match lock.try_lock() { + Ok(guard) => { + this.quota_write_wake_scheduled = false; + this.quota_write_retry_active.store(false, Ordering::Relaxed); + Some(guard) } + Err(_) => { + if !this.quota_write_wake_scheduled { + this.quota_write_wake_scheduled = true; + this.quota_write_retry_active.store(true, Ordering::Relaxed); + spawn_quota_retry_waker( + Arc::clone(&this.quota_write_retry_active), + cx.waker().clone(), + ); + } + return Poll::Pending; + } + } + } else { + None + }; + + let write_buf = if let Some(limit) = this.quota_limit { + let used = this.stats.get_user_total_octets(&this.user); + 
if used >= limit { + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + + let remaining = (limit - used) as usize; + if buf.len() > remaining { + // Fail closed: do not emit partial S->C payload when remaining + // quota cannot accommodate the pending write request. + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + buf + } else { + buf + }; + + match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + Poll::Ready(Ok(n)) => { if n > 0 { // S→C: data written to client this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed); @@ -303,19 +516,17 @@ impl AsyncWrite for StatsIo { this.stats.add_user_octets_to(&this.user, n as u64); this.stats.increment_user_msgs_to(&this.user); + if let Some(limit) = this.quota_limit + && this.stats.get_user_total_octets(&this.user) >= limit + { + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + trace!(user = %this.user, bytes = n, "S->C"); } Poll::Ready(Ok(n)) } - Poll::Pending => { - this.counters - .s2c_pending_writes - .fetch_add(1, Ordering::Relaxed); - this.counters - .s2c_consecutive_pending_writes - .fetch_add(1, Ordering::Relaxed); - Poll::Pending - } other => other, } } @@ -348,7 +559,8 @@ impl AsyncWrite for StatsIo { /// - Per-user stats: bytes and ops counted per direction /// - Periodic rate logging: every 10 seconds when active /// - Clean shutdown: both write sides are shut down on exit -/// - Error propagation: I/O errors are returned as `ProxyError::Io` +/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`, +/// other I/O failures are returned as `ProxyError::Io` pub async fn relay_bidirectional( client_reader: CR, client_writer: CW, @@ -357,11 +569,9 @@ pub async fn relay_bidirectional( c2s_buf_size: usize, s2c_buf_size: usize, user: &str, - dc_idx: i16, stats: Arc, + quota_limit: Option, _buffer_pool: Arc, - session_lease: 
SessionLease, - seed_tier: AdaptiveTier, ) -> Result<()> where CR: AsyncRead + Unpin + Send + 'static, @@ -371,6 +581,7 @@ where { let epoch = Instant::now(); let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); let user_owned = user.to_string(); // ── Combine split halves into bidirectional streams ────────────── @@ -383,43 +594,31 @@ where Arc::clone(&counters), Arc::clone(&stats), user_owned.clone(), + quota_limit, + Arc::clone("a_exceeded), epoch, ); // ── Watchdog: activity timeout + periodic rate logging ────────── let wd_counters = Arc::clone(&counters); let wd_user = user_owned.clone(); - let wd_dc = dc_idx; - let wd_stats = Arc::clone(&stats); - let wd_session = session_lease.clone(); + let wd_quota_exceeded = Arc::clone("a_exceeded); let watchdog = async { - let mut prev_c2s_log: u64 = 0; - let mut prev_s2c_log: u64 = 0; - let mut prev_c2s_sample: u64 = 0; - let mut prev_s2c_requested_sample: u64 = 0; - let mut prev_s2c_written_sample: u64 = 0; - let mut prev_s2c_write_ops_sample: u64 = 0; - let mut prev_s2c_partial_sample: u64 = 0; - let mut accumulated_log = Duration::ZERO; - let mut adaptive = SessionAdaptiveController::new(seed_tier); + let mut prev_c2s: u64 = 0; + let mut prev_s2c: u64 = 0; loop { - tokio::time::sleep(ADAPTIVE_TICK).await; - - if wd_session.is_stale() { - wd_stats.increment_reconnect_stale_close_total(); - warn!( - user = %wd_user, - dc = wd_dc, - "Session evicted by reconnect" - ); - return; - } + tokio::time::sleep(WATCHDOG_INTERVAL).await; let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); + if wd_quota_exceeded.load(Ordering::Relaxed) { + warn!(user = %wd_user, "User data quota reached, closing relay"); + return; + } + // ── Activity timeout ──────────────────────────────────── if idle >= ACTIVITY_TIMEOUT { let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); @@ -434,80 +633,11 @@ where return; // Causes select! 
to cancel copy_bidirectional } - let c2s_total = wd_counters.c2s_bytes.load(Ordering::Relaxed); - let s2c_requested_total = wd_counters - .s2c_requested_bytes - .load(Ordering::Relaxed); - let s2c_written_total = wd_counters.s2c_bytes.load(Ordering::Relaxed); - let s2c_write_ops_total = wd_counters - .s2c_write_ops - .load(Ordering::Relaxed); - let s2c_partial_total = wd_counters - .s2c_partial_writes - .load(Ordering::Relaxed); - let consecutive_pending = wd_counters - .s2c_consecutive_pending_writes - .load(Ordering::Relaxed) as u32; - - let sample = RelaySignalSample { - c2s_bytes: c2s_total.saturating_sub(prev_c2s_sample), - s2c_requested_bytes: s2c_requested_total - .saturating_sub(prev_s2c_requested_sample), - s2c_written_bytes: s2c_written_total - .saturating_sub(prev_s2c_written_sample), - s2c_write_ops: s2c_write_ops_total - .saturating_sub(prev_s2c_write_ops_sample), - s2c_partial_writes: s2c_partial_total - .saturating_sub(prev_s2c_partial_sample), - s2c_consecutive_pending_writes: consecutive_pending, - }; - - if let Some(transition) = adaptive.observe(sample, ADAPTIVE_TICK.as_secs_f64()) { - match transition.reason { - TierTransitionReason::SoftConfirmed => { - wd_stats.increment_relay_adaptive_promotions_total(); - } - TierTransitionReason::HardPressure => { - wd_stats.increment_relay_adaptive_promotions_total(); - wd_stats.increment_relay_adaptive_hard_promotions_total(); - } - TierTransitionReason::QuietDemotion => { - wd_stats.increment_relay_adaptive_demotions_total(); - } - } - adaptive_buffers::record_user_tier(&wd_user, adaptive.max_tier_seen()); - debug!( - user = %wd_user, - dc = wd_dc, - from_tier = transition.from.as_u8(), - to_tier = transition.to.as_u8(), - reason = ?transition.reason, - throughput_ema_bps = sample - .c2s_bytes - .max(sample.s2c_written_bytes) - .saturating_mul(8) - .saturating_mul(4), - "Adaptive relay tier transition" - ); - } - - prev_c2s_sample = c2s_total; - prev_s2c_requested_sample = s2c_requested_total; - 
prev_s2c_written_sample = s2c_written_total; - prev_s2c_write_ops_sample = s2c_write_ops_total; - prev_s2c_partial_sample = s2c_partial_total; - - accumulated_log = accumulated_log.saturating_add(ADAPTIVE_TICK); - if accumulated_log < WATCHDOG_INTERVAL { - continue; - } - accumulated_log = Duration::ZERO; - // ── Periodic rate logging ─────────────────────────────── let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed); let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed); - let c2s_delta = c2s.saturating_sub(prev_c2s_log); - let s2c_delta = s2c.saturating_sub(prev_s2c_log); + let c2s_delta = watchdog_delta(c2s, prev_c2s); + let s2c_delta = watchdog_delta(s2c, prev_s2c); if c2s_delta > 0 || s2c_delta > 0 { let secs = WATCHDOG_INTERVAL.as_secs_f64(); @@ -521,8 +651,8 @@ where ); } - prev_c2s_log = c2s; - prev_s2c_log = s2c; + prev_c2s = c2s; + prev_s2c = s2c; } }; @@ -557,7 +687,6 @@ where let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed); let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed); let duration = epoch.elapsed(); - adaptive_buffers::record_user_tier(&user_owned, seed_tier); match copy_result { Some(Ok((c2s, s2c))) => { @@ -573,6 +702,22 @@ where ); Ok(()) } + Some(Err(e)) if is_quota_io_error(&e) => { + let c2s = counters.c2s_bytes.load(Ordering::Relaxed); + let s2c = counters.s2c_bytes.load(Ordering::Relaxed); + warn!( + user = %user_owned, + c2s_bytes = c2s, + s2c_bytes = s2c, + c2s_msgs = c2s_ops, + s2c_msgs = s2c_ops, + duration_secs = duration.as_secs(), + "Data quota reached, closing relay" + ); + Err(ProxyError::DataQuotaExceeded { + user: user_owned.clone(), + }) + } Some(Err(e)) => { // I/O error in one of the directions let c2s = counters.c2s_bytes.load(Ordering::Relaxed); @@ -606,3 +751,39 @@ where } } } + +#[cfg(test)] +#[path = "tests/relay_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/relay_adversarial_tests.rs"] +mod adversarial_tests; + +#[cfg(test)] +#[path = 
"tests/relay_quota_lock_pressure_adversarial_tests.rs"] +mod relay_quota_lock_pressure_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_boundary_blackhat_tests.rs"] +mod relay_quota_boundary_blackhat_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_model_adversarial_tests.rs"] +mod relay_quota_model_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_overflow_regression_tests.rs"] +mod relay_quota_overflow_regression_tests; + +#[cfg(test)] +#[path = "tests/relay_watchdog_delta_security_tests.rs"] +mod relay_watchdog_delta_security_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] +mod relay_quota_waker_storm_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] +mod relay_quota_wake_liveness_regression_tests; \ No newline at end of file diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index 306c536..e2232d2 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::watch; -pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover"; +pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Session terminated"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(u8)] @@ -14,17 +14,6 @@ pub(crate) enum RelayRouteMode { } impl RelayRouteMode { - pub(crate) fn as_u8(self) -> u8 { - self as u8 - } - - pub(crate) fn from_u8(value: u8) -> Self { - match value { - 1 => Self::Middle, - _ => Self::Direct, - } - } - pub(crate) fn as_str(self) -> &'static str { match self { Self::Direct => "direct", @@ -41,8 +30,6 @@ pub(crate) struct RouteCutoverState { #[derive(Clone)] pub(crate) struct RouteRuntimeController { - mode: Arc, - generation: Arc, direct_since_epoch_secs: Arc, tx: watch::Sender, } @@ -60,18 +47,13 @@ impl 
RouteRuntimeController { 0 }; Self { - mode: Arc::new(AtomicU8::new(initial_mode.as_u8())), - generation: Arc::new(AtomicU64::new(0)), direct_since_epoch_secs: Arc::new(AtomicU64::new(direct_since_epoch_secs)), tx, } } pub(crate) fn snapshot(&self) -> RouteCutoverState { - RouteCutoverState { - mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)), - generation: self.generation.load(Ordering::Relaxed), - } + *self.tx.borrow() } pub(crate) fn subscribe(&self) -> watch::Receiver { @@ -84,20 +66,28 @@ impl RouteRuntimeController { } pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option { - let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed); - if previous == mode.as_u8() { + let mut next = None; + let changed = self.tx.send_if_modified(|state| { + if state.mode == mode { + return false; + } + if matches!(mode, RelayRouteMode::Direct) { + self.direct_since_epoch_secs + .store(now_epoch_secs(), Ordering::Relaxed); + } else { + self.direct_since_epoch_secs.store(0, Ordering::Relaxed); + } + state.mode = mode; + state.generation = state.generation.saturating_add(1); + next = Some(*state); + true + }); + + if !changed { return None; } - if matches!(mode, RelayRouteMode::Direct) { - self.direct_since_epoch_secs - .store(now_epoch_secs(), Ordering::Relaxed); - } else { - self.direct_since_epoch_secs.store(0, Ordering::Relaxed); - } - let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; - let next = RouteCutoverState { mode, generation }; - self.tx.send_replace(next); - Some(next) + + next } } @@ -110,10 +100,10 @@ fn now_epoch_secs() -> u64 { pub(crate) fn is_session_affected_by_cutover( current: RouteCutoverState, - _session_mode: RelayRouteMode, + session_mode: RelayRouteMode, session_generation: u64, ) -> bool { - current.generation > session_generation + current.generation > session_generation && current.mode != session_mode } pub(crate) fn affected_cutover_state( @@ -140,3 +130,11 @@ pub(crate) fn 
cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio let ms = 1000 + (value % 1000); Duration::from_millis(ms) } + +#[cfg(test)] +#[path = "tests/route_mode_security_tests.rs"] +mod security_tests; + +#[cfg(test)] +#[path = "tests/route_mode_coherence_adversarial_tests.rs"] +mod coherence_adversarial_tests; diff --git a/src/proxy/tests/client_adversarial_tests.rs b/src/proxy/tests/client_adversarial_tests.rs new file mode 100644 index 0000000..0e780e3 --- /dev/null +++ b/src/proxy/tests/client_adversarial_tests.rs @@ -0,0 +1,669 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::ip_tracker::UserIpTracker; +use crate::error::ProxyError; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +// ------------------------------------------------------------------ +// Priority 3: Massive Concurrency Stress (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_stress_10k_connections_limit_strict() { + let user = "stress-user"; + let limit = 512; + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), limit); + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let user_str = user.to_string(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, (i % 254 + 1) as u8)), + 10000 + (i % 1000) as u16, + ); + + match RunningClientHandler::acquire_user_connection_reservation_static( + &user_str, + &config, + stats, + peer, + ip_tracker, + ).await { + Ok(res) => Ok(res), + Err(ProxyError::ConnectionLimitExceeded { .. 
}) => Err(()), + Err(e) => panic!("Unexpected error: {:?}", e), + } + })); + } + + let results = futures::future::join_all(tasks).await; + let mut successes = 0; + let mut failures = 0; + let mut reservations = Vec::new(); + + for res in results { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, limit, "Should allow exactly 'limit' connections"); + assert_eq!(failures, iterations - limit, "Should fail the rest with LimitExceeded"); + assert_eq!(stats.get_user_curr_connects(user), limit as u64); + + drop(reservations); + + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0, "Stats must converge to 0 after all drops"); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP tracker must converge to 0"); +} + +// ------------------------------------------------------------------ +// Priority 3: IP Tracker Race Stress +// ------------------------------------------------------------------ + +#[tokio::test] +async fn client_ip_tracker_race_condition_stress() { + let user = "race-user"; + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 100).await; + + let iterations = 1000; + let mut tasks = Vec::new(); + + for i in 0..iterations { + let ip_tracker = Arc::clone(&ip_tracker); + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 254 + 1) as u8)); + + tasks.push(tokio::spawn(async move { + for _ in 0..10 { + if let Ok(()) = ip_tracker.check_and_add("race-user", ip).await { + ip_tracker.remove_ip("race-user", ip).await; + } + } + })); + } + + futures::future::join_all(tasks).await; + + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst"); +} + +#[tokio::test] +async fn client_limit_burst_peak_never_exceeds_cap() { + let user = "peak-cap-user"; + let limit = 32; + let attempts = 256; + + let stats = Arc::new(Stats::new()); + let ip_tracker = 
Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), limit); + + let peak = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::with_capacity(attempts); + + for i in 0..attempts { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let peak = Arc::clone(&peak); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 250 + 1) as u8)), + 20000 + i as u16, + ); + + let acquired = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker, + ) + .await; + + if let Ok(reservation) = acquired { + let now = stats.get_user_curr_connects(user); + loop { + let prev = peak.load(Ordering::Relaxed); + if now <= prev { + break; + } + if peak + .compare_exchange(prev, now, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + break; + } + } + tokio::time::sleep(Duration::from_millis(2)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert!( + peak.load(Ordering::Relaxed) <= limit as u64, + "peak concurrent reservations must not exceed configured cap" + ); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_quota_rejection_never_mutates_live_counters() { + let user = "quota-reject-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 0); + + let peer: SocketAddr = "198.51.100.201:31111".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, 
Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_expiration_rejection_never_mutates_live_counters() { + let user = "expired-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_expirations + .insert(user.to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + + let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, Err(ProxyError::UserExpired { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_ip_limit_failure_rolls_back_counter_exactly() { + let user = "ip-limit-rollback-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 16); + + let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + + let second_peer: SocketAddr = "198.51.100.204:31114".parse().unwrap(); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(second, Err(ProxyError::ConnectionLimitExceeded { .. 
}))); + assert_eq!(stats.get_user_curr_connects(user), 1); + + drop(first); + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_success_path_leaves_no_residue() { + let user = "parallel-check-success-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 128).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 128); + + let mut tasks = Vec::new(); + for i in 0..128u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 10, (i / 255) as u8, (i % 255 + 1) as u8)), + 32000 + i, + ); + RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + for result in futures::future::join_all(tasks).await { + assert!(result.unwrap().is_ok()); + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_failure_path_leaves_no_residue() { + let user = "parallel-check-failure-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 0).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 512); + + let mut tasks = Vec::new(); + for i in 0..64u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), 33000 + i); + 
RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + let mut _denied = 0usize; + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(()) => {} + Err(ProxyError::ConnectionLimitExceeded { .. }) => _denied += 1, + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_churn_mixed_success_failure_converges_to_zero_state() { + let user = "mixed-churn-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let mut tasks = Vec::new(); + for i in 0..200u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(192, 0, 2, (i % 16 + 1) as u8)), + 34000 + (i % 32), + ); + let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + if let Ok(reservation) = maybe_res { + tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() { + let user = "same-ip-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + 
config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let peer: SocketAddr = "203.0.113.44:35555".parse().unwrap(); + let mut tasks = Vec::new(); + + for _ in 0..64 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + })); + } + + let mut granted = 0usize; + let mut reservations = Vec::new(); + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(reservation) => { + granted += 1; + reservations.push(reservation); + } + Err(ProxyError::ConnectionLimitExceeded { .. }) => {} + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(granted, 1, "only one reservation may be granted for same IP with limit=1"); + drop(reservations); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_repeat_acquire_release_cycles_never_accumulate_state() { + let user = "repeat-cycle-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 32).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 32); + + for i in 0..500u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, (i / 250) as u8, (i % 250 + 1) as u8)), + 36000 + (i % 128), + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + drop(reservation); + } + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] 
+async fn client_multi_user_isolation_under_parallel_limit_exhaustion() { + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert("u1".to_string(), 8); + config.access.user_max_tcp_conns.insert("u2".to_string(), 8); + + let mut tasks = Vec::new(); + for i in 0..128u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + let user = if i % 2 == 0 { "u1" } else { "u2" }; + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(100, 64, (i / 64) as u8, (i % 64 + 1) as u8)), + 37000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + })); + } + + let mut u1_success = 0usize; + let mut u2_success = 0usize; + let mut reservations = Vec::new(); + for (idx, result) in futures::future::join_all(tasks).await.into_iter().enumerate() { + let user = if idx % 2 == 0 { "u1" } else { "u2" }; + match result.unwrap() { + Ok(reservation) => { + if user == "u1" { + u1_success += 1; + } else { + u2_success += 1; + } + reservations.push(reservation); + } + Err(ProxyError::ConnectionLimitExceeded { .. 
}) => {} + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(u1_success, 8, "u1 must get exactly its own configured cap"); + assert_eq!(u2_success, 8, "u2 must get exactly its own configured cap"); + + drop(reservations); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects("u1"), 0); + assert_eq!(stats.get_user_curr_connects("u2"), 0); +} + +#[tokio::test] +async fn client_limit_recovery_after_full_rejection_wave() { + let user = "recover-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let first_peer: SocketAddr = "198.51.100.50:38001".parse().unwrap(); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + + for i in 0..64u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 51, 100, (i % 60 + 1) as u8)), + 38002 + i, + ); + let denied = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. 
}))); + } + + drop(reservation); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + + let recovery_peer: SocketAddr = "198.51.100.200:38999".parse().unwrap(); + let recovered = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + recovery_peer, + ip_tracker.clone(), + ) + .await; + assert!(recovered.is_ok(), "capacity must recover after prior holder drops"); +} + +#[tokio::test] +async fn client_dual_limit_cross_product_never_leaks_on_reject() { + let user = "dual-limit-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 2); + + let p1: SocketAddr = "203.0.113.10:39001".parse().unwrap(); + let p2: SocketAddr = "203.0.113.11:39002".parse().unwrap(); + let r1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + p1, + ip_tracker.clone(), + ) + .await + .unwrap(); + let r2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + p2, + ip_tracker.clone(), + ) + .await + .unwrap(); + + for i in 0..32u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (50 + i) as u8)), + 39010 + i, + ); + let denied = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(denied, Err(ProxyError::ConnectionLimitExceeded { .. 
}))); + } + + assert_eq!(stats.get_user_curr_connects(user), 2); + drop((r1, r2)); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_check_user_limits_concurrent_churn_no_counter_drift() { + let user = "check-drift-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 64).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 64); + + let mut tasks = Vec::new(); + for i in 0..512u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 20, (i / 255) as u8, (i % 255 + 1) as u8)), + 40000 + (i % 500), + ); + let _ = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer, + &ip_tracker, + ) + .await; + })); + } + + for task in futures::future::join_all(tasks).await { + task.unwrap(); + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs new file mode 100644 index 0000000..80f9834 --- /dev/null +++ b/src/proxy/tests/client_beobachten_ttl_bounds_security_tests.rs @@ -0,0 +1,126 @@ +use super::*; + +const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60; + +#[test] +fn beobachten_ttl_exact_upper_bound_is_preserved() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "upper-bound TTL should remain unchanged" 
+ ); +} + +#[test] +fn beobachten_ttl_above_upper_bound_is_clamped() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "TTL above security cap must be clamped" + ); +} + +#[test] +fn beobachten_ttl_u64_max_is_clamped_fail_safe() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = u64::MAX; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "extreme configured TTL must not become multi-century retention" + ); +} + +#[test] +fn positive_one_minute_maps_to_exact_60_seconds() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + assert_eq!(beobachten_ttl(&config), Duration::from_secs(60)); +} + +#[test] +fn adversarial_boundary_triplet_behaves_deterministically() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES - 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs((BEOBACHTEN_TTL_MAX_MINUTES - 1) * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); + + config.general.beobachten_minutes = BEOBACHTEN_TTL_MAX_MINUTES + 1; + assert_eq!( + beobachten_ttl(&config), + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60) + ); +} + +#[test] +fn light_fuzz_random_minutes_match_fail_safe_model() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut seed = 0xD15E_A5E5_F00D_BAADu64; + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + 
config.general.beobachten_minutes = seed; + let ttl = beobachten_ttl(&config); + let expected = if seed == 0 { + Duration::from_secs(60) + } else { + Duration::from_secs(seed.min(BEOBACHTEN_TTL_MAX_MINUTES) * 60) + }; + + assert_eq!(ttl, expected, "ttl mismatch for minutes={seed}"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + } +} + +#[test] +fn stress_monotonic_minutes_remain_monotonic_until_cap_then_flat() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + + let mut prev = Duration::from_secs(0); + for minutes in 0..=(BEOBACHTEN_TTL_MAX_MINUTES + 4096) { + config.general.beobachten_minutes = minutes; + let ttl = beobachten_ttl(&config); + + assert!(ttl >= prev, "ttl must be non-decreasing as minutes grow"); + assert!(ttl <= Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60)); + + if minutes > BEOBACHTEN_TTL_MAX_MINUTES { + assert_eq!( + ttl, + Duration::from_secs(BEOBACHTEN_TTL_MAX_MINUTES * 60), + "ttl must stay clamped once cap is exceeded" + ); + } + prev = ttl; + } +} diff --git a/src/proxy/tests/client_masking_blackhat_campaign_tests.rs b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs new file mode 100644 index 0000000..3ea9dae --- /dev/null +++ b/src/proxy/tests/client_masking_blackhat_campaign_tests.rs @@ -0,0 +1,893 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{ + HANDSHAKE_LEN, + MAX_TLS_PLAINTEXT_SIZE, + MIN_TLS_CLIENT_HELLO_SIZE, + TLS_RECORD_APPLICATION, + TLS_VERSION, +}; +use crate::protocol::tls; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct CampaignHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: 
Arc, +} + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_mask_harness(secret_hex: &str, mask_port: u16) -> CampaignHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + CampaignHarness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let 
mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(record_type); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + wrap_tls_record(TLS_RECORD_APPLICATION, payload) +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_tls_success_mtproto_fail_capture( + harness: CampaignHarness, + peer: SocketAddr, + client_hello: Vec, + bad_mtproto_record: Vec, + trailing_records: Vec>, + expected_forward: Vec, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + + let expected = expected_forward.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + cfg, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + 
client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&bad_mtproto_record).await.unwrap(); + for record in trailing_records { + client_side.write_all(&record).await.unwrap(); + } + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_forward); + + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +async fn run_invalid_tls_capture(config: Arc, payload: Vec, expected: Vec) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = (*config).clone(); + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + + let expected_probe = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.77:45001".parse().unwrap(), + cfg, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, 
+ )); + + client_side.write_all(&payload).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn blackhat_campaign_01_tail_only_record_is_forwarded_after_tls_success_mtproto_fail() { + let secret = [0xA1u8; 16]; + let harness = build_mask_harness("a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1", 1); + let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x41); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let tail = wrap_tls_application_data(b"blackhat-tail-01"); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.1:55001".parse().unwrap(), + client_hello, + bad_record, + vec![tail.clone()], + tail, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_02_two_ordered_records_preserved_after_fallback() { + let secret = [0xA2u8; 16]; + let harness = build_mask_harness("a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2", 1); + let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x42); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let r1 = wrap_tls_application_data(b"first"); + let r2 = wrap_tls_application_data(b"second"); + let expected = [r1.clone(), r2.clone()].concat(); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.2:55002".parse().unwrap(), + client_hello, + bad_record, + vec![r1, r2], + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_03_large_tls_application_record_survives_fallback() { + let secret = [0xA3u8; 16]; + let harness = build_mask_harness("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3", 1); + let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x43); + let bad_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let 
big_payload = vec![0x5Au8; MAX_TLS_PLAINTEXT_SIZE]; + let big_record = wrap_tls_application_data(&big_payload); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.3:55003".parse().unwrap(), + client_hello, + bad_record, + vec![big_record.clone()], + big_record, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_04_coalesced_tail_in_failed_record_is_reframed_and_forwarded() { + let secret = [0xA4u8; 16]; + let harness = build_mask_harness("a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4", 1); + let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x44); + + let coalesced_tail = b"coalesced-tail-blackhat".to_vec(); + let mut bad_payload = vec![0u8; HANDSHAKE_LEN]; + bad_payload.extend_from_slice(&coalesced_tail); + let bad_record = wrap_tls_application_data(&bad_payload); + let expected = wrap_tls_application_data(&coalesced_tail); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.4:55004".parse().unwrap(), + client_hello, + bad_record, + Vec::new(), + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_05_coalesced_tail_plus_next_record_keep_wire_order() { + let secret = [0xA5u8; 16]; + let harness = build_mask_harness("a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5", 1); + let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x45); + + let coalesced_tail = b"inline-tail".to_vec(); + let mut bad_payload = vec![0u8; HANDSHAKE_LEN]; + bad_payload.extend_from_slice(&coalesced_tail); + let bad_record = wrap_tls_application_data(&bad_payload); + let next_record = wrap_tls_application_data(b"next-record"); + + let expected = [ + wrap_tls_application_data(&coalesced_tail), + next_record.clone(), + ] + .concat(); + + run_tls_success_mtproto_fail_capture( + harness, + "198.51.100.5:55005".parse().unwrap(), + client_hello, + bad_record, + vec![next_record], + expected, + ) + .await; +} + +#[tokio::test] +async fn blackhat_campaign_06_replayed_tls_hello_is_masked_without_serverhello() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_mask_harness("a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6a6", backend_addr.port()); + let replay_checker = harness.replay_checker.clone(); + let client_hello = make_valid_tls_client_hello(&[0xA6; 16], 16, 600, 0x46); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let first_tail = wrap_tls_application_data(b"seed-tail"); + + let expected_hello = client_hello.clone(); + let expected_tail = first_tail.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got_tail = vec![0u8; expected_tail.len()]; + s1.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + s2.read_exact(&mut got_hello).await.unwrap(); + got_hello + }); + + let run_one = |checker: Arc, send_mtproto: bool| { + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + let cfg = Arc::new(cfg); + let hello = client_hello.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let first_tail = first_tail.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + + async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55006".parse().unwrap(), + cfg, + stats, + upstream, + checker, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + 
client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&first_tail).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err()); + } + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + } + }; + + run_one(replay_checker.clone(), true).await; + run_one(replay_checker, false).await; + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, client_hello); +} + +#[tokio::test] +async fn blackhat_campaign_07_truncated_clienthello_exact_prefix_is_forwarded() { + let mut payload = vec![0u8; 5 + 37]; + payload[0] = 0x16; + payload[1] = 0x03; + payload[2] = 0x01; + payload[3..5].copy_from_slice(&600u16.to_be_bytes()); + payload[5..].fill(0x71); + + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), payload.clone(), payload).await; +} + +#[tokio::test] +async fn blackhat_campaign_08_out_of_bounds_len_forwards_header_only() { + let header = vec![0x16, 0x03, 0x01, 0xFF, 0xFF]; + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), header.clone(), header).await; +} + +#[tokio::test] +async fn blackhat_campaign_09_fragmented_header_then_partial_body_masks_seen_bytes_only() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + + let 
expected = { + let mut x = vec![0u8; 5 + 11]; + x[0] = 0x16; + x[1] = 0x03; + x[2] = 0x01; + x[3..5].copy_from_slice(&600u16.to_be_bytes()); + x[5..].fill(0xCC); + x + }; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.9:55009".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03]).await.unwrap(); + client_side.write_all(&[0x01, 0x02, 0x58]).await.unwrap(); + client_side.write_all(&vec![0xCC; 11]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got.len(), 16); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn blackhat_campaign_10_zero_handshake_timeout_with_delay_still_avoids_timeout_counter() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 0; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + let 
started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.10:55010".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut invalid = vec![0u8; 5 + 700]; + invalid[0] = 0x16; + invalid[1] = 0x03; + invalid[2] = 0x01; + invalid[3..5].copy_from_slice(&700u16.to_be_bytes()); + invalid[5..].fill(0x66); + + client_side.write_all(&invalid).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert!(started.elapsed() >= Duration::from_millis(650)); +} + +#[tokio::test] +async fn blackhat_campaign_11_parallel_bad_tls_probes_all_masked_without_timeouts() { + let n = 24usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_port = backend_addr.port(); + + let stats = Arc::new(Stats::new()); + let accept_task = tokio::spawn(async move { + let mut seen = HashSet::new(); + for _ in 0..n { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut hdr = [0u8; 5]; + stream.read_exact(&mut hdr).await.unwrap(); + seen.insert(hdr.to_vec()); + } + seen + }); + + let mut tasks = Vec::new(); + for i in 0..n { + let mut hdr = [0u8; 5]; + hdr[0] = 0x16; + hdr[1] = 0x03; + hdr[2] = 0x01; + hdr[3] = 0xFF; + hdr[4] = i as u8; + + let cfg = 
Arc::new(cfg.clone()); + let stats = stats.clone(); + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + format!("198.51.100.11:{}", 56000 + i).parse().unwrap(), + cfg, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + client_side.write_all(&hdr).await.unwrap(); + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + hdr.to_vec() + })); + } + + let mut expected = HashSet::new(); + for t in tasks { + expected.insert(t.await.unwrap()); + } + + let seen = tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(seen, expected); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_12_parallel_tls_success_mtproto_fail_sessions_keep_isolation() { + let sessions = 16usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = HashSet::new(); + for i in 0..sessions { + let rec = wrap_tls_application_data(&vec![i as u8; 8 + i]); + expected.insert(rec); + } + + let accept_task = tokio::spawn(async move { + let mut got_set = HashSet::new(); + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut head = [0u8; 5]; + stream.read_exact(&mut head).await.unwrap(); + let len = u16::from_be_bytes([head[3], head[4]]) as usize; + let mut rec = vec![0u8; 5 + len]; + rec[..5].copy_from_slice(&head); + stream.read_exact(&mut rec[5..]).await.unwrap(); + 
got_set.insert(rec); + } + got_set + }); + + let mut tasks = Vec::new(); + for i in 0..sessions { + let mut harness = build_mask_harness("abababababababababababababababab", backend_addr.port()); + let mut cfg = (*harness.config).clone(); + cfg.censorship.mask_port = backend_addr.port(); + harness.config = Arc::new(cfg); + tasks.push(tokio::spawn(async move { + let secret = [0xABu8; 16]; + let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x40 + (i as u8 % 10)); + let bad = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let tail = wrap_tls_application_data(&vec![i as u8; 8 + i]); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + format!("198.51.100.12:{}", 56100 + i).parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&bad).await.unwrap(); + client_side.write_all(&tail).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(5), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + tail + })); + } + + let mut produced = HashSet::new(); + for t in tasks { + produced.insert(t.await.unwrap()); + } + + let observed = tokio::time::timeout(Duration::from_secs(8), accept_task) + .await + .unwrap() + .unwrap(); + + assert_eq!(produced, expected); + assert_eq!(observed, expected); +} + +#[tokio::test] +async fn blackhat_campaign_13_backend_down_does_not_escalate_to_handshake_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; 
+ cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 1; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55013".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let bad = vec![0x16, 0x03, 0x01, 0xFF, 0x00]; + client_side.write_all(&bad).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_14_masking_disabled_path_finishes_cleanly() { + let mut cfg = ProxyConfig::default(); + cfg.censorship.mask = false; + cfg.timeouts.client_handshake = 1; + + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.14:55014".parse().unwrap(), + Arc::new(cfg), + stats.clone(), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let bad = vec![0x16, 0x03, 0x01, 0xFF, 0xF0]; + client_side.write_all(&bad).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), 
handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); +} + +#[tokio::test] +async fn blackhat_campaign_15_light_fuzz_tls_lengths_and_fragmentation() { + let mut seed = 0x9E3779B97F4A7C15u64; + + for idx in 0..20u16 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let mut tls_len = (seed as usize) % 20000; + if idx % 3 == 0 { + tls_len = MAX_TLS_PLAINTEXT_SIZE + 1 + (tls_len % 1024); + } + + let body_to_send = if (MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&tls_len) + { + (seed as usize % 29).min(tls_len.saturating_sub(1)) + } else { + 0 + }; + + let mut probe = vec![0u8; 5 + body_to_send]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + for b in &mut probe[5..] { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + *b = (seed >> 24) as u8; + } + + let expected = probe.clone(); + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe, expected).await; + } +} + +#[tokio::test] +async fn blackhat_campaign_16_mixed_probe_burst_stress_finishes_without_panics() { + let cases = 18usize; + let mut tasks = Vec::new(); + + for i in 0..cases { + tasks.push(tokio::spawn(async move { + if i % 2 == 0 { + let mut probe = vec![0u8; 5 + (i % 13)]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill((0x90 + i as u8) ^ 0x5A); + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), probe.clone(), probe).await; + } else { + let hdr = vec![0x16, 0x03, 0x01, 0xFF, i as u8]; + run_invalid_tls_capture(Arc::new(ProxyConfig::default()), hdr.clone(), hdr).await; + } + })); + } + + for task in tasks { + task.await.unwrap(); + } +} diff --git a/src/proxy/tests/client_masking_budget_security_tests.rs b/src/proxy/tests/client_masking_budget_security_tests.rs new file mode 100644 index 0000000..8dcf114 --- 
/dev/null +++ b/src/proxy/tests/client_masking_budget_security_tests.rs @@ -0,0 +1,244 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(config: ProxyConfig) -> PipelineHarness { + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + 
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[tokio::test] +async fn masking_runs_outside_handshake_timeout_budget_with_high_reject_delay() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + config.timeouts.client_handshake = 0; + config.censorship.server_hello_delay_min_ms = 730; + config.censorship.server_hello_delay_max_ms = 730; + + let harness = build_harness(config); + let stats = harness.stats.clone(); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.241:56541".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + let mut invalid_hello = vec![0u8; 5 + 600]; + 
invalid_hello[0] = 0x16; + invalid_hello[1] = 0x03; + invalid_hello[2] = 0x01; + invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + invalid_hello[5..].fill(0x44); + + let started = Instant::now(); + client_side.write_all(&invalid_hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_ok(), "bad-client fallback must not be canceled by handshake timeout"); + assert_eq!( + stats.get_handshake_timeouts(), + 0, + "masking fallback path must not increment handshake timeout counter" + ); + assert!( + started.elapsed() >= Duration::from_millis(700), + "configured reject delay should still be visible before masking" + ); +} + +#[tokio::test] +async fn tls_mtproto_bad_client_does_not_reinject_clienthello_into_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + config.access.ignore_time_skew = true; + config + .access + .users + .insert("user".to_string(), "d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0".to_string()); + + let harness = build_harness(config); + + let secret = [0xD0u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x41); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"no-clienthello-reinject"); + let expected_trailing = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_trailing.len()]; + 
stream.read_exact(&mut got).await.unwrap(); + assert_eq!( + got, + expected_trailing, + "mask backend must receive only post-handshake trailing TLS records" + ); + }); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.242:56542".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} diff --git a/src/proxy/tests/client_masking_diagnostics_security_tests.rs b/src/proxy/tests/client_masking_diagnostics_security_tests.rs new file mode 100644 index 0000000..1d069c6 --- /dev/null +++ b/src/proxy/tests/client_masking_diagnostics_security_tests.rs @@ -0,0 +1,193 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + 
selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn percentile_ms(mut values: Vec, p_num: usize, p_den: usize) -> u128 { + values.sort_unstable(); + if values.is_empty() { + return 0; + } + let idx = ((values.len() - 1) * p_num) / p_den; + values[idx] +} + +async fn measure_reject_duration_ms(body_sent: usize) -> u128 { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 1; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.170:56170".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0xA7); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + started.elapsed().as_millis() +} + +async fn capture_forwarded_len(body_sent: usize) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + 
cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = false; + cfg.timeouts.client_handshake = 1; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.171:56171".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0xB4); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn diagnostic_timing_profiles_are_within_realistic_guardrails() { + let classes = [17usize, 511usize, 1023usize, 4095usize]; + for class in classes { + let mut samples = Vec::new(); + for _ in 0..8 { + samples.push(measure_reject_duration_ms(class).await); + } + + let p50 = percentile_ms(samples.clone(), 50, 100); + let p95 = percentile_ms(samples.clone(), 95, 100); + let max = 
*samples.iter().max().unwrap(); + println!( + "diagnostic_timing class={} p50={}ms p95={}ms max={}ms", + class, p50, p95, max + ); + + assert!(p50 >= 650, "p50 too low for delayed reject class={}", class); + assert!(p95 <= 1200, "p95 too high for delayed reject class={}", class); + assert!(max <= 1500, "max too high for delayed reject class={}", class); + } +} + +#[tokio::test] +async fn diagnostic_forwarded_size_profiles_by_probe_class() { + let classes = [0usize, 1usize, 7usize, 17usize, 63usize, 511usize, 1023usize, 2047usize]; + let mut observed = Vec::new(); + + for class in classes { + let len = capture_forwarded_len(class).await; + println!("diagnostic_shape class={} forwarded_len={}", class, len); + observed.push(len as u128); + assert_eq!(len, 5 + class, "unexpected forwarded len for class={}", class); + } + + let p50 = percentile_ms(observed.clone(), 50, 100); + let p95 = percentile_ms(observed.clone(), 95, 100); + let max = *observed.iter().max().unwrap(); + println!( + "diagnostic_shape_summary p50={}bytes p95={}bytes max={}bytes", + p50, p95, max + ); + + assert!(p95 >= p50); + assert!(max >= p95); +} diff --git a/src/proxy/tests/client_masking_hard_adversarial_tests.rs b/src/proxy/tests/client_masking_hard_adversarial_tests.rs new file mode 100644 index 0000000..cdaede5 --- /dev/null +++ b/src/proxy/tests/client_masking_hard_adversarial_tests.rs @@ -0,0 +1,701 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct Harness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn new_upstream_manager(stats: Arc) -> Arc { + 
Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> Harness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + Harness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(512, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + 
digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_tls_success_mtproto_fail_capture( + secret_hex: &str, + secret: [u8; 16], + timestamp: u32, + trailing_records: Vec>, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let expected_len = trailing_records.iter().map(Vec::len).sum::(); + let expected_concat = trailing_records.concat(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_len]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let harness = build_harness(secret_hex, backend_addr.port()); + let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.210:56010".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + 
let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + for record in trailing_records { + client_side.write_all(&record).await.unwrap(); + } + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_concat); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + got +} + +#[tokio::test] +async fn masking_budget_survives_zero_handshake_timeout_with_delay() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.timeouts.client_handshake = 0; + cfg.censorship.server_hello_delay_min_ms = 720; + cfg.censorship.server_hello_delay_max_ms = 720; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; 605]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:56011".parse().unwrap(), + config, + stats.clone(), + new_upstream_manager(stats.clone()), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + 
Arc::new(BeobachtenStore::new()), + false, + )); + + let mut invalid_hello = vec![0u8; 605]; + invalid_hello[0] = 0x16; + invalid_hello[1] = 0x03; + invalid_hello[2] = 0x01; + invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + invalid_hello[5..].fill(0xA1); + + let started = Instant::now(); + client_side.write_all(&invalid_hello).await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + client_side.shutdown().await.unwrap(); + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_ok()); + assert_eq!(stats.get_handshake_timeouts(), 0); + assert!(started.elapsed() >= Duration::from_millis(680)); +} + +#[tokio::test] +async fn tls_mtproto_fail_forwards_only_trailing_record() { + let tail = wrap_tls_application_data(b"tail-only"); + let got = run_tls_success_mtproto_fail_capture( + "c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1", + [0xC1; 16], + 1, + vec![tail.clone()], + ) + .await; + assert_eq!(got, tail); +} + +#[tokio::test] +async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port()); + let secret = [0xC2u8; 16]; + let hello = make_valid_tls_client_hello(&secret, 2, 600, 0x41); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let first_tail = wrap_tls_application_data(b"seed"); + + let expected_hello = hello.clone(); + let expected_tail = first_tail.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got_tail = vec![0u8; expected_tail.len()]; + s1.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got_hello = 
vec![0u8; expected_hello.len()]; + s2.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + }); + + let run_session = |send_mtproto: bool| { + let (server_side, mut client_side) = duplex(131072); + let config = harness.config.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let replay = harness.replay_checker.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + let hello = hello.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let first_tail = first_tail.clone(); + + async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.212:56012".parse().unwrap(), + config, + stats, + upstream, + replay, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_tls_record_body(&mut client_side, head).await; + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&first_tail).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err()); + } + + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } + }; + + run_session(true).await; + run_session(false).await; + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn connects_bad_increments_once_per_invalid_mtproto() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port()); + let stats = harness.stats.clone(); + let bad_before = stats.get_connects_bad(); + + let tail = wrap_tls_application_data(b"accounting"); + let expected_tail = tail.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_tail); + }); + + let hello = make_valid_tls_client_hello(&[0xC3; 16], 3, 600, 0x42); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.213:56013".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_tls_record_body(&mut client_side, head).await; + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&tail).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert_eq!(stats.get_connects_bad(), bad_before + 1); +} + +#[tokio::test] +async fn truncated_clienthello_forwards_only_seen_prefix() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + 
cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let expected_prefix_len = 5 + 17; + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_prefix_len]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.214:56014".parse().unwrap(), + config, + stats, + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut hello = vec![0u8; 5 + 17]; + hello[0] = 0x16; + hello[1] = 0x03; + hello[2] = 0x01; + hello[3..5].copy_from_slice(&600u16.to_be_bytes()); + hello[5..].fill(0x55); + + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, hello); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn out_of_bounds_tls_len_forwards_header_only() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = 
None; + + let config = Arc::new(cfg); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(8192); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.215:56015".parse().unwrap(), + config, + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let hdr = [0x16, 0x03, 0x01, 0x42, 0x69]; + client_side.write_all(&hdr).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, hdr); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn non_tls_with_modes_disabled_is_masked() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_unix_sock = None; + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + + let config = Arc::new(cfg); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(8192); + let handler = tokio::spawn(handle_client_stream( + server_side, + 
"198.51.100.216:56016".parse().unwrap(), + config, + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let probe = *b"HELLO"; + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let got = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, probe); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn concurrent_tls_mtproto_fail_sessions_are_isolated() { + let sessions = 12usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..sessions { + let payload = vec![idx as u8; 32 + idx]; + expected.insert(wrap_tls_application_data(&payload)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + assert!(remaining.remove(&record)); + } + assert!(remaining.is_empty()); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"; + let harness = build_harness(secret_hex, backend_addr.port()); + let hello = make_valid_tls_client_hello(&[0xC4; 
16], 20 + idx as u32, 600, 0x40 + idx as u8); + let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]); + let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + read_tls_record_body(&mut client_side, head).await; + client_side.write_all(&invalid_mtproto).await.unwrap(); + client_side.write_all(&trailing).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +macro_rules! 
tail_length_case { + ($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => { + #[tokio::test] + async fn $name() { + let mut payload = vec![0u8; $len]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(5); + } + let record = wrap_tls_application_data(&payload); + let got = run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()]).await; + assert_eq!(got, record); + } + }; +} + +tail_length_case!(tail_len_1_preserved, "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", [0xD1; 16], 30, 1); +tail_length_case!(tail_len_2_preserved, "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2", [0xD2; 16], 31, 2); +tail_length_case!(tail_len_3_preserved, "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", [0xD3; 16], 32, 3); +tail_length_case!(tail_len_7_preserved, "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4", [0xD4; 16], 33, 7); +tail_length_case!(tail_len_31_preserved, "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5", [0xD5; 16], 34, 31); +tail_length_case!(tail_len_127_preserved, "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6", [0xD6; 16], 35, 127); +tail_length_case!(tail_len_511_preserved, "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7", [0xD7; 16], 36, 511); +tail_length_case!(tail_len_1023_preserved, "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8", [0xD8; 16], 37, 1023); diff --git a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs new file mode 100644 index 0000000..1208071 --- /dev/null +++ b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs @@ -0,0 +1,344 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( 
+ vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn masking_config(mask_port: u16) -> Arc { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + Arc::new(cfg) +} + +async fn run_generic_probe_and_capture_prefix(payload: Vec, expected_prefix: Vec) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let reply = REPLY_404.to_vec(); + let prefix_len = expected_prefix.len(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; prefix_len]; + stream.read_exact(&mut got).await.unwrap(); + stream.write_all(&reply).await.unwrap(); + got + }); + + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + 
None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&payload).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let got = tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_prefix); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +async fn read_http_probe_header(stream: &mut TcpStream) -> Vec { + let mut out = Vec::with_capacity(96); + let mut one = [0u8; 1]; + + loop { + stream.read_exact(&mut one).await.unwrap(); + out.push(one[0]); + if out.ends_with(b"\r\n\r\n") { + break; + } + assert!( + out.len() <= 512, + "probe header exceeded sane limit while waiting for terminator" + ); + } + + out +} + +#[tokio::test] +async fn blackhat_fragmented_plain_http_probe_masks_and_preserves_prefix() { + let payload = b"GET /probe-evasion HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + run_generic_probe_and_capture_prefix(payload.clone(), payload).await; +} + +#[tokio::test] +async fn blackhat_invalid_tls_like_probe_masks_and_preserves_header_prefix() { + let payload = vec![0x16, 0x03, 0x03, 0x00, 0x64, 0x01, 0x00]; + run_generic_probe_and_capture_prefix(payload.clone(), payload).await; +} + +#[tokio::test] +async fn integration_client_handler_plain_probe_masks_and_preserves_prefix() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let payload = b"GET /integration-probe HTTP/1.1\r\nHost: a.example\r\n\r\n".to_vec(); + let expected_prefix = payload.clone(); + + let 
accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_prefix.len()]; + stream.read_exact(&mut got).await.unwrap(); + stream.write_all(REPLY_404).await.unwrap(); + got + }); + + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&payload).await.unwrap(); + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let got = tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + 
.unwrap(); + assert_eq!(got, payload); + + let result = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_small_probe_variants_always_mask_and_preserve_declared_prefix() { + let mut rng = StdRng::seed_from_u64(0xA11E_5EED_F0F0_CAFE); + + for i in 0..24usize { + let mut payload = if rng.random::() { + b"GET /fuzz HTTP/1.1\r\nHost: fuzz.example\r\n\r\n".to_vec() + } else { + vec![0x16, 0x03, 0x03, 0x00, 0x64] + }; + + let tail_len = rng.random_range(0..=8usize); + for _ in 0..tail_len { + payload.push(rng.random::()); + } + + let expected_prefix = payload.clone(); + run_generic_probe_and_capture_prefix(payload, expected_prefix).await; + + if i % 6 == 0 { + tokio::task::yield_now().await; + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { + let session_count = 12usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..session_count { + let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + expected.insert(probe); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..session_count { + let (mut stream, _) = listener.accept().await.unwrap(); + let head = read_http_probe_header(&mut stream).await; + stream.write_all(REPLY_404).await.unwrap(); + assert!(remaining.remove(&head), "backend received unexpected or duplicated probe prefix"); + } + assert!(remaining.is_empty(), "all session prefixes must be observed exactly once"); + }); + + let mut tasks = Vec::with_capacity(session_count); + for idx in 0..session_count { + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = 
make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..08d276d --- /dev/null +++ b/src/proxy/tests/client_masking_redteam_expected_fail_tests.rs @@ -0,0 +1,556 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION}; +use crate::protocol::tls; +use 
tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +struct RedTeamHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> RedTeamHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + RedTeamHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: 
usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn run_tls_success_mtproto_fail_session( + secret_hex: &str, + secret: [u8; 16], + timestamp: u32, + tail: Vec, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let harness = build_harness(secret_hex, backend_addr.port()); + let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(&tail); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; trailing_record.len()]; + stream.read_exact(&mut got).await.unwrap(); + got + }); + + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.250:56900".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + 
client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + let body_len = u16::from_be_bytes([head[3], head[4]]) as usize; + let mut body = vec![0u8; body_len]; + client_side.read_exact(&mut body).await.unwrap(); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&wrap_tls_application_data(&tail)).await.unwrap(); + + let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + forwarded +} + +#[tokio::test] +#[ignore = "red-team expected-fail: demonstrates that post-TLS fallback still forwards data to backend"] +async fn redteam_01_backend_receives_no_data_after_mtproto_fail() { + let forwarded = run_tls_success_mtproto_fail_session( + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + [0xAA; 16], + 1, + b"probe-a".to_vec(), + ) + .await; + assert!(forwarded.is_empty(), "backend unexpectedly received fallback bytes"); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict no-fallback policy hypothesis"] +async fn redteam_02_backend_must_never_receive_tls_records_after_mtproto_fail() { + let forwarded = run_tls_success_mtproto_fail_session( + "abababababababababababababababab", + [0xAB; 16], + 2, + b"probe-b".to_vec(), + ) + .await; + assert_ne!(forwarded[0], 0x17, "received TLS application record despite strict policy"); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: impossible timing uniformity target"] +async fn redteam_03_masking_duration_must_be_less_than_1ms_when_backend_down() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "acacacacacacacacacacacacacacacac".to_string()); + + let 
harness = RedTeamHarness { + config: Arc::new(cfg), + stats: Arc::new(Stats::new()), + upstream_manager: Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + Arc::new(Stats::new()), + )), + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + }; + + let hello = make_valid_tls_client_hello(&[0xAC; 16], 3, 600, 0x42); + let (server_side, mut client_side) = duplex(131072); + + let started = Instant::now(); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.251:56901".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + assert!(started.elapsed() < Duration::from_millis(1), "fallback path took longer than 1ms"); +} + +macro_rules! 
redteam_tail_must_not_forward_case { + ($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => { + #[tokio::test] + #[ignore = "red-team expected-fail: strict no-forwarding hypothesis"] + async fn $name() { + let mut tail = vec![0u8; $len]; + for (i, b) in tail.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(31).wrapping_add(7); + } + let forwarded = run_tls_success_mtproto_fail_session($hex, $secret, $ts, tail).await; + assert!( + forwarded.is_empty(), + "strict model expects zero forwarded bytes, got {}", + forwarded.len() + ); + } + }; +} + +redteam_tail_must_not_forward_case!(redteam_04_tail_len_1_not_forwarded, "adadadadadadadadadadadadadadadad", [0xAD; 16], 4, 1); +redteam_tail_must_not_forward_case!(redteam_05_tail_len_2_not_forwarded, "aeaeaeaeaeaeaeaeaeaeaeaeaeaeaeae", [0xAE; 16], 5, 2); +redteam_tail_must_not_forward_case!(redteam_06_tail_len_3_not_forwarded, "afafafafafafafafafafafafafafafaf", [0xAF; 16], 6, 3); +redteam_tail_must_not_forward_case!(redteam_07_tail_len_7_not_forwarded, "b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0", [0xB0; 16], 7, 7); +redteam_tail_must_not_forward_case!(redteam_08_tail_len_15_not_forwarded, "b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", [0xB1; 16], 8, 15); +redteam_tail_must_not_forward_case!(redteam_09_tail_len_63_not_forwarded, "b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", [0xB2; 16], 9, 63); +redteam_tail_must_not_forward_case!(redteam_10_tail_len_127_not_forwarded, "b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", [0xB3; 16], 10, 127); +redteam_tail_must_not_forward_case!(redteam_11_tail_len_255_not_forwarded, "b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", [0xB4; 16], 11, 255); +redteam_tail_must_not_forward_case!(redteam_12_tail_len_511_not_forwarded, "b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5", [0xB5; 16], 12, 511); +redteam_tail_must_not_forward_case!(redteam_13_tail_len_1023_not_forwarded, "b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6", [0xB6; 16], 13, 1023); +redteam_tail_must_not_forward_case!(redteam_14_tail_len_2047_not_forwarded, "b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7b7", [0xB7; 
16], 14, 2047); +redteam_tail_must_not_forward_case!(redteam_15_tail_len_4095_not_forwarded, "b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8", [0xB8; 16], 15, 4095); + +#[tokio::test] +#[ignore = "red-team expected-fail: impossible indistinguishability envelope"] +async fn redteam_16_timing_delta_between_paths_must_be_sub_1ms_under_concurrency() { + let runs = 20usize; + let mut durations = Vec::with_capacity(runs); + + for i in 0..runs { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let secret = [0xB9u8; 16]; + let harness = build_harness("b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9b9", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 100 + i as u32, 600, 0x42); + + let accept_task = tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.252:56902".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + let started = Instant::now(); + client_side.write_all(&hello).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + durations.push(started.elapsed()); + } + + let min = durations.iter().copied().min().unwrap(); + let max = durations.iter().copied().max().unwrap(); + assert!(max - min <= Duration::from_millis(1), "timing spread too wide for strict anti-probing envelope"); +} + +async fn measure_invalid_probe_duration_ms( + delay_ms: u64, + tls_len: u16, + body_sent: usize, +) -> u128 { + let mut cfg = ProxyConfig::default(); + 
cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.timeouts.client_handshake = 1; + cfg.censorship.server_hello_delay_min_ms = delay_ms; + cfg.censorship.server_hello_delay_max_ms = delay_ms; + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.253:56903".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + Arc::new(Stats::new()), + )), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0xD7); + + let started = Instant::now(); + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +async fn capture_forwarded_probe_len(tls_len: u16, body_sent: usize) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + 
cfg.censorship.mask_port = backend_addr.port(); + cfg.timeouts.client_handshake = 1; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.254:56904".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + Arc::new(Stats::new()), + )), + Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0xBC); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +macro_rules! 
redteam_timing_envelope_case { + ($name:ident, $delay_ms:expr, $tls_len:expr, $body_sent:expr, $max_ms:expr) => { + #[tokio::test] + #[ignore = "red-team expected-fail: unrealistically tight reject timing envelope"] + async fn $name() { + let elapsed_ms = measure_invalid_probe_duration_ms($delay_ms, $tls_len, $body_sent).await; + assert!( + elapsed_ms <= $max_ms, + "timing envelope violated: elapsed={}ms, max={}ms", + elapsed_ms, + $max_ms + ); + } + }; +} + +macro_rules! redteam_constant_shape_case { + ($name:ident, $tls_len:expr, $body_sent:expr, $expected_len:expr) => { + #[tokio::test] + #[ignore = "red-team expected-fail: strict constant-shape backend fingerprint hypothesis"] + async fn $name() { + let got = capture_forwarded_probe_len($tls_len, $body_sent).await; + assert_eq!( + got, + $expected_len, + "fingerprint shape mismatch: got={} expected={} (strict constant-shape model)", + got, + $expected_len + ); + } + }; +} + +redteam_timing_envelope_case!(redteam_17_timing_env_very_tight_00, 700, 600, 0, 3); +redteam_timing_envelope_case!(redteam_18_timing_env_very_tight_01, 700, 600, 1, 3); +redteam_timing_envelope_case!(redteam_19_timing_env_very_tight_02, 700, 600, 7, 3); +redteam_timing_envelope_case!(redteam_20_timing_env_very_tight_03, 700, 600, 17, 3); +redteam_timing_envelope_case!(redteam_21_timing_env_very_tight_04, 700, 600, 31, 3); +redteam_timing_envelope_case!(redteam_22_timing_env_very_tight_05, 700, 600, 63, 3); +redteam_timing_envelope_case!(redteam_23_timing_env_very_tight_06, 700, 600, 127, 3); +redteam_timing_envelope_case!(redteam_24_timing_env_very_tight_07, 700, 600, 255, 3); +redteam_timing_envelope_case!(redteam_25_timing_env_very_tight_08, 700, 600, 511, 3); +redteam_timing_envelope_case!(redteam_26_timing_env_very_tight_09, 700, 600, 1023, 3); +redteam_timing_envelope_case!(redteam_27_timing_env_very_tight_10, 700, 600, 2047, 3); +redteam_timing_envelope_case!(redteam_28_timing_env_very_tight_11, 700, 600, 4095, 3); + 
+redteam_constant_shape_case!(redteam_29_constant_shape_00, 600, 0, 517); +redteam_constant_shape_case!(redteam_30_constant_shape_01, 600, 1, 517); +redteam_constant_shape_case!(redteam_31_constant_shape_02, 600, 7, 517); +redteam_constant_shape_case!(redteam_32_constant_shape_03, 600, 17, 517); +redteam_constant_shape_case!(redteam_33_constant_shape_04, 600, 31, 517); +redteam_constant_shape_case!(redteam_34_constant_shape_05, 600, 63, 517); +redteam_constant_shape_case!(redteam_35_constant_shape_06, 600, 127, 517); +redteam_constant_shape_case!(redteam_36_constant_shape_07, 600, 255, 517); +redteam_constant_shape_case!(redteam_37_constant_shape_08, 600, 511, 517); +redteam_constant_shape_case!(redteam_38_constant_shape_09, 600, 1023, 517); +redteam_constant_shape_case!(redteam_39_constant_shape_10, 600, 2047, 517); +redteam_constant_shape_case!(redteam_40_constant_shape_11, 600, 4095, 517); diff --git a/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..5b5344d --- /dev/null +++ b/src/proxy/tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs @@ -0,0 +1,245 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + 
/// Pearson correlation coefficient of two equally-sized samples.
///
/// Returns 0.0 for empty input, mismatched lengths, or when either sample has
/// zero variance — the coefficient is undefined in those cases and 0.0 is the
/// conservative "no signal" answer for these shape-classifier tests.
fn pearson_corr(xs: &[f64], ys: &[f64]) -> f64 {
    if xs.is_empty() || xs.len() != ys.len() {
        return 0.0;
    }

    let n = xs.len() as f64;
    let mean_x = xs.iter().sum::<f64>() / n;
    let mean_y = ys.iter().sum::<f64>() / n;

    // Accumulate covariance and both variances in a single pass.
    let (cov, var_x, var_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, vx, vy), (&x, &y)| {
            let dx = x - mean_x;
            let dy = y - mean_y;
            (cov + dx * dy, vx + dx * dx, vy + dy * dy)
        },
    );

    if var_x == 0.0 || var_y == 0.0 {
        0.0
    } else {
        cov / (var_x.sqrt() * var_y.sqrt())
    }
}

/// Deterministic pseudo-random probe sizes in `[0, 3 * cap)` produced by a
/// 64-bit LCG (no RNG crate needed), plus a fixed tail of boundary-straddling
/// values around `floor` and `cap` so bucket edges are always exercised.
fn lcg_sizes(count: usize, floor: usize, cap: usize) -> Vec<usize> {
    let mut state = 0x9E3779B97F4A7C15u64;
    let span = cap.saturating_mul(3);
    let mut sizes = Vec::with_capacity(count + 8);

    for _ in 0..count {
        // Knuth's MMIX LCG constants; fully deterministic across runs.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        sizes.push((state as usize) % span.max(1));
    }

    // Inject edge and boundary-heavy probes.
    sizes.extend_from_slice(&[
        0,
        floor.saturating_sub(1),
        floor,
        floor.saturating_add(1),
        cap.saturating_sub(1),
        cap,
        cap.saturating_add(1),
        cap.saturating_mul(2),
    ]);
    sizes
}
redteam_fuzz_02_hardened_unique_output_ratio_should_be_below_5pct() { + let floor = 512usize; + let cap = 4096usize; + let sizes = lcg_sizes(24, floor, cap); + + let hardened = collect_distribution(&sizes, true, floor, cap).await; + + let in_unique = { + let mut s = std::collections::BTreeSet::new(); + for v in &sizes { + s.insert(*v); + } + s.len() + }; + + let out_unique = { + let mut s = std::collections::BTreeSet::new(); + for v in &hardened { + s.insert(*v); + } + s.len() + }; + + let ratio = out_unique as f64 / in_unique as f64; + println!( + "redteam_fuzz unique_ratio_hardened={ratio:.4} out_unique={} in_unique={}", + out_unique, in_unique + ); + + assert!( + ratio <= 0.05, + "strict model expects near-total collapse; observed ratio={ratio:.4}" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict separability improvement target"] +async fn redteam_fuzz_03_hardened_signal_must_be_10x_lower_than_plain() { + let floor = 512usize; + let cap = 4096usize; + let sizes = lcg_sizes(24, floor, cap); + + let plain = collect_distribution(&sizes, false, floor, cap).await; + let hardened = collect_distribution(&sizes, true, floor, cap).await; + + let x: Vec = sizes.iter().map(|v| *v as f64).collect(); + let y_plain: Vec = plain.iter().map(|v| *v as f64).collect(); + let y_hard: Vec = hardened.iter().map(|v| *v as f64).collect(); + + let corr_plain = pearson_corr(&x, &y_plain).abs(); + let corr_hard = pearson_corr(&x, &y_hard).abs(); + + println!( + "redteam_fuzz corr_plain={corr_plain:.4} corr_hardened={corr_hard:.4}" + ); + + assert!( + corr_hard <= corr_plain * 0.1, + "strict model expects 10x suppression; plain={corr_plain:.4} hardened={corr_hard:.4}" + ); +} diff --git a/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs new file mode 100644 index 0000000..6ce57b3 --- /dev/null +++ b/src/proxy/tests/client_masking_shape_hardening_adversarial_tests.rs @@ -0,0 +1,179 
@@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn expected_bucket(total: usize, floor: usize, cap: usize) -> usize { + if total == 0 || floor == 0 || cap < floor { + return total; + } + + if total >= cap { + return total; + } + + let mut bucket = floor; + while bucket < total { + match bucket.checked_mul(2) { + Some(next) => bucket = next, + None => return total, + } + if bucket > cap { + return cap; + } + } + bucket +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + 
"198.51.100.199:56999".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn shape_hardening_non_power_of_two_cap_collapses_probe_classes() { + let floor = 1000usize; + let cap = 1500usize; + + let low = run_probe_capture(1195, 700, true, floor, cap).await; + let high = run_probe_capture(1494, 700, true, floor, cap).await; + + assert_eq!(low.len(), 1500); + assert_eq!(high.len(), 1500); +} + +#[tokio::test] +async fn shape_hardening_disabled_keeps_non_power_of_two_cap_lengths_distinct() { + let floor = 1000usize; + let cap = 1500usize; + + let low = run_probe_capture(1195, 700, false, floor, cap).await; + let high = run_probe_capture(1494, 700, false, floor, cap).await; + + assert_eq!(low.len(), 1200); + assert_eq!(high.len(), 1499); +} + +#[tokio::test] +async fn shape_hardening_parallel_stress_collapses_sub_cap_probes() { + let floor = 1000usize; + let cap = 1500usize; + let mut tasks = Vec::new(); + + for idx in 0..24usize { + let body = 1001 + (idx * 19 % 480); + tasks.push(tokio::spawn(async move { + run_probe_capture(body, 1200, true, floor, cap).await.len() + })); + } + + for task in tasks { + 
let observed = task.await.unwrap(); + assert_eq!(observed, 1500); + } +} + +#[tokio::test] +async fn shape_hardening_light_fuzz_matches_bucket_oracle() { + let floor = 512usize; + let cap = 4096usize; + + for step in 1usize..=36usize { + let total = 1 + (((step * 313) ^ (step << 7)) % (cap + 300)); + let body = total.saturating_sub(5); + + let got = run_probe_capture(body, 650, true, floor, cap).await; + let expected = expected_bucket(total, floor, cap); + assert_eq!( + got.len(), + expected, + "step={step} total={total} expected={expected} got={} ", + got.len() + ); + } +} diff --git a/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..a835d00 --- /dev/null +++ b/src/proxy/tests/client_masking_shape_hardening_redteam_expected_fail_tests.rs @@ -0,0 +1,236 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + 
cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:57011".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +async fn measure_reject_ms(body_sent: usize) -> u128 { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = 1; + cfg.censorship.server_hello_delay_min_ms = 700; + cfg.censorship.server_hello_delay_max_ms = 700; + + let (server_side, mut client_side) = duplex(65536); + let started = Instant::now(); + + let handler = tokio::spawn(handle_client_stream( 
+ server_side, + "198.51.100.212:57012".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&600u16.to_be_bytes()); + probe[5..].fill(0x44); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +#[tokio::test] +#[ignore = "red-team expected-fail: above-cap exact length still leaks classifier signal"] +async fn redteam_shape_01_above_cap_flows_should_collapse_to_single_class() { + let floor = 512usize; + let cap = 4096usize; + + let a = run_probe_capture(5000, 7000, true, floor, cap).await; + let b = run_probe_capture(6000, 7000, true, floor, cap).await; + + assert_eq!( + a.len(), + b.len(), + "strict anti-classifier model expects same backend length class above cap" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: current padding bytes are deterministic zeros"] +async fn redteam_shape_02_padding_tail_must_be_non_deterministic() { + let floor = 512usize; + let cap = 4096usize; + let got = run_probe_capture(17, 600, true, floor, cap).await; + + assert!( + got.len() > 22, + "test requires padding tail to exist" + ); + + let tail = &got[22..]; + assert!( + tail.iter().any(|b| *b != 0), + "padding tail is fully zeroed and thus deterministic" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: exact-floor probes still expose boundary class"] +async fn 
redteam_shape_03_exact_floor_input_should_not_be_fixed_point() { + let floor = 512usize; + let cap = 4096usize; + let got = run_probe_capture(507, 600, true, floor, cap).await; + + assert!( + got.len() > floor, + "strict model expects extra blur even when input lands exactly on floor" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict one-bucket collapse hypothesis"] +async fn redteam_shape_04_all_sub_cap_sizes_should_collapse_to_single_size() { + let floor = 512usize; + let cap = 4096usize; + let classes = [17usize, 63usize, 255usize, 511usize, 1023usize, 2047usize, 3071usize]; + + let mut observed = Vec::new(); + for body in classes { + observed.push(run_probe_capture(body, 1200, true, floor, cap).await.len()); + } + + let first = observed[0]; + for v in observed { + assert_eq!(v, first, "strict model expects one collapsed class across all sub-cap probes"); + } +} + +#[tokio::test] +#[ignore = "red-team expected-fail: over-strict micro-timing invariance"] +async fn redteam_shape_05_reject_timing_spread_should_be_under_2ms() { + let classes = [17usize, 511usize, 1023usize, 2047usize, 4095usize]; + let mut values = Vec::new(); + + for class in classes { + values.push(measure_reject_ms(class).await); + } + + let min = *values.iter().min().unwrap(); + let max = *values.iter().max().unwrap(); + assert!( + min == 700 && max == 700, + "strict model requires exact 700ms for every malformed class: min={min}ms max={max}ms" + ); +} + +#[test] +#[ignore = "red-team expected-fail: secure-by-default hypothesis"] +fn redteam_shape_06_shape_hardening_should_be_secure_by_default() { + let cfg = ProxyConfig::default(); + assert!( + cfg.censorship.mask_shape_hardening, + "strict model expects shape hardening enabled by default" + ); +} diff --git a/src/proxy/tests/client_masking_shape_hardening_security_tests.rs b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs new file mode 100644 index 0000000..f9c0f17 --- /dev/null +++ 
b/src/proxy/tests/client_masking_shape_hardening_security_tests.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_capture( + body_sent: usize, + tls_len: u16, + enable_shape_hardening: bool, + floor: usize, + cap: usize, +) -> Vec { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_shape_hardening = enable_shape_hardening; + cfg.censorship.mask_shape_bucket_floor_bytes = floor; + cfg.censorship.mask_shape_bucket_cap_bytes = cap; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got + }); + + let (server_side, mut client_side) = duplex(65536); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.188:56888".parse().unwrap(), + Arc::new(cfg), + Arc::new(Stats::new()), + new_upstream_manager(Arc::new(Stats::new())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), 
+ None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + probe[3..5].copy_from_slice(&tls_len.to_be_bytes()); + probe[5..].fill(0x66); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn shape_hardening_disabled_keeps_original_probe_length() { + let got = run_probe_capture(17, 600, false, 512, 4096).await; + assert_eq!(got.len(), 22); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_enabled_pads_small_probe_to_floor_bucket() { + let got = run_probe_capture(17, 600, true, 512, 4096).await; + assert_eq!(got.len(), 512); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_enabled_pads_mid_probe_to_next_bucket() { + let got = run_probe_capture(511, 600, true, 512, 4096).await; + assert_eq!(got.len(), 1024); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x02, 0x58]); +} + +#[tokio::test] +async fn shape_hardening_respects_cap_and_avoids_padding_above_cap() { + let got = run_probe_capture(5000, 7000, true, 512, 4096).await; + assert_eq!(got.len(), 5005); + assert_eq!(&got[..5], &[0x16, 0x03, 0x01, 0x1b, 0x58]); +} diff --git a/src/proxy/tests/client_masking_stress_adversarial_tests.rs b/src/proxy/tests/client_masking_stress_adversarial_tests.rs new file mode 100644 index 0000000..52e7da1 --- /dev/null +++ b/src/proxy/tests/client_masking_stress_adversarial_tests.rs @@ -0,0 +1,254 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, 
TLS_RECORD_APPLICATION, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +struct StressHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn new_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn build_harness(mask_port: u16, secret_hex: &str) -> StressHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + StressHarness { + config, + stats: stats.clone(), + upstream_manager: new_upstream_manager(stats), + replay_checker: Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] 
= 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(TLS_RECORD_APPLICATION); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +async fn run_parallel_tail_fallback_case( + sessions: usize, + payload_len: usize, + write_chunk: usize, + ts_base: u32, + peer_port_base: u16, +) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..sessions { + let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; + expected.insert(wrap_tls_application_data(&payload)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + let len = u16::from_be_bytes([header[3], header[4]]) as usize; 
+ let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + assert!(remaining.remove(&record)); + } + assert!(remaining.is_empty()); + }); + + let mut tasks = Vec::with_capacity(sessions); + + for idx in 0..sessions { + let harness = build_harness(backend_addr.port(), "e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0e0"); + let hello = make_valid_tls_client_hello( + &[0xE0; 16], + ts_base + idx as u32, + 600, + 0x40 + (idx as u8), + ); + + let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let payload = vec![((idx * 37) & 0xff) as u8; payload_len + idx % 3]; + let trailing = wrap_tls_application_data(&payload); + // Keep source IPs unique across stress cases so global pre-auth probe state + // cannot contaminate unrelated sessions and make this test nondeterministic. + let peer_ip_third = 100 + ((ts_base as u8) / 10); + let peer_ip_fourth = (idx as u8).saturating_add(1); + let peer: SocketAddr = format!( + "198.51.{}.{}:{}", + peer_ip_third, + peer_ip_fourth, + peer_port_base + idx as u16 + ) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut server_hello_head = [0u8; 5]; + client_side.read_exact(&mut server_hello_head).await.unwrap(); + assert_eq!(server_hello_head[0], 0x16); + read_tls_record_body(&mut client_side, server_hello_head).await; + + client_side.write_all(&invalid_mtproto).await.unwrap(); + for chunk in trailing.chunks(write_chunk.max(1)) { + client_side.write_all(chunk).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let _ = 
tokio::time::timeout(Duration::from_secs(4), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(8), accept_task) + .await + .unwrap() + .unwrap(); +} + +macro_rules! stress_case { + ($name:ident, $sessions:expr, $payload_len:expr, $chunk:expr, $ts:expr, $port:expr) => { + #[tokio::test] + async fn $name() { + run_parallel_tail_fallback_case($sessions, $payload_len, $chunk, $ts, $port).await; + } + }; +} + +stress_case!(stress_masking_parallel_s01, 4, 16, 1, 1000, 57000); +stress_case!(stress_masking_parallel_s02, 5, 24, 2, 1010, 57010); +stress_case!(stress_masking_parallel_s03, 6, 32, 3, 1020, 57020); +stress_case!(stress_masking_parallel_s04, 7, 40, 4, 1030, 57030); +stress_case!(stress_masking_parallel_s05, 8, 48, 5, 1040, 57040); +stress_case!(stress_masking_parallel_s06, 9, 56, 6, 1050, 57050); +stress_case!(stress_masking_parallel_s07, 10, 64, 7, 1060, 57060); +stress_case!(stress_masking_parallel_s08, 11, 72, 8, 1070, 57070); +stress_case!(stress_masking_parallel_s09, 12, 80, 9, 1080, 57080); +stress_case!(stress_masking_parallel_s10, 13, 88, 10, 1090, 57090); +stress_case!(stress_masking_parallel_s11, 6, 128, 11, 1100, 57100); +stress_case!(stress_masking_parallel_s12, 7, 160, 12, 1110, 57110); +stress_case!(stress_masking_parallel_s13, 8, 192, 13, 1120, 57120); +stress_case!(stress_masking_parallel_s14, 9, 224, 14, 1130, 57130); +stress_case!(stress_masking_parallel_s15, 10, 256, 15, 1140, 57140); +stress_case!(stress_masking_parallel_s16, 11, 288, 16, 1150, 57150); +stress_case!(stress_masking_parallel_s17, 12, 320, 17, 1160, 57160); +stress_case!(stress_masking_parallel_s18, 13, 352, 18, 1170, 57170); +stress_case!(stress_masking_parallel_s19, 14, 384, 19, 1180, 57180); +stress_case!(stress_masking_parallel_s20, 15, 416, 20, 1190, 57190); +stress_case!(stress_masking_parallel_s21, 16, 448, 21, 1200, 57200); +stress_case!(stress_masking_parallel_s22, 17, 
480, 22, 1210, 57210); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs new file mode 100644 index 0000000..98e3cd1 --- /dev/null +++ b/src/proxy/tests/client_security_tests.rs @@ -0,0 +1,4436 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::AesCtr; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::ProtoTag; +use crate::protocol::tls; +use crate::proxy::handshake::HandshakeSuccess; +use crate::stream::{CryptoReader, CryptoWriter}; +use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use std::net::Ipv4Addr; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +#[test] +fn synthetic_local_addr_uses_configured_port_for_zero() { + let addr = synthetic_local_addr(0); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), 0); +} + +#[test] +fn synthetic_local_addr_uses_configured_port_for_max() { + let addr = synthetic_local_addr(u16::MAX); + assert_eq!(addr.ip(), IpAddr::from([0, 0, 0, 0])); + assert_eq!(addr.port(), u16::MAX); +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { + let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); + let stats = Arc::new(crate::stats::Stats::new()); + let user = "sync-drop-user".to_string(); + let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap(); + + ip_tracker.set_user_limit(&user, 1).await; + ip_tracker.check_and_add(&user, ip).await.unwrap(); + stats.increment_user_curr_connects(&user); + + 
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1); + assert_eq!(stats.get_user_curr_connects(&user), 1); + + let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip); + + // Drop the reservation synchronously without any tokio::spawn/await yielding! + drop(reservation); + + // The IP is now inside the cleanup_queue, check that the queue has length 1 + let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len(); + assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue"); + + assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately"); + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0); +} + +#[tokio::test] +async fn relay_task_abort_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "abort-user"; + let peer_addr: SocketAddr = "198.51.100.230:50000".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = 
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime, + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("relay must reserve user slot and IP before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "task abort must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "task abort must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn relay_cutover_releases_user_gate_and_ip_reservation() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = 
stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let user = "cutover-user"; + let peer_addr: SocketAddr = "198.51.100.231:50001".parse().unwrap(); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut cfg = ProxyConfig::default(); + cfg.access.user_max_tcp_conns.insert(user.to_string(), 8); + cfg.dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(cfg); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: user.to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: peer_addr, + is_tls: false, + }; + + let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + None, + route_runtime.clone(), + "127.0.0.1:443".parse().unwrap(), + peer_addr, + ip_tracker.clone(), + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 1 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + 
.expect("relay must reserve user slot and IP before cutover");

    // Cutover: advancing the route mode must return a new generation marker;
    // the in-flight direct relay is expected to observe it and terminate.
    assert!(
        route_runtime.set_mode(RelayRouteMode::Middle).is_some(),
        "cutover must advance route generation"
    );

    let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task)
        .await
        .expect("relay must terminate after cutover")
        .expect("relay task must not panic");
    assert!(relay_result.is_err(), "cutover must terminate direct relay session");

    // Both per-user accounting surfaces must drain on the cutover exit path:
    // the current-connection counter and the tracked IP footprint.
    assert_eq!(
        stats.get_user_curr_connects(user),
        0,
        "cutover exit must release user current-connection slot"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "cutover exit must release reserved user IP footprint"
    );

    drop(client_side);
    tg_accept_task.abort();
    let _ = tg_accept_task.await;
}

// Races a Direct->Middle route cutover against a 1-byte per-user data quota:
// whichever enforcement path wins, the session must fail closed and release
// all per-user state (connection slot + IP reservation).
#[tokio::test]
async fn integration_route_cutover_and_quota_overlap_fails_closed_and_releases_state() {
    // Stand-in Telegram DC endpoint: serves two bytes so the relay can make
    // measurable progress against the 1-byte quota, then lingers briefly.
    let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let tg_addr = tg_listener.local_addr().unwrap();

    let tg_accept_task = tokio::spawn(async move {
        let (mut stream, _) = tg_listener.accept().await.unwrap();
        stream.write_all(&[0x41, 0x42]).await.unwrap();
        tokio::time::sleep(Duration::from_secs(1)).await;
    });

    let user = "cutover-quota-overlap-user";
    let peer_addr: SocketAddr = "198.51.100.240:50010".parse().unwrap();

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut cfg = ProxyConfig::default();
    cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
    // 1-byte quota: any relayed payload should trip quota enforcement.
    cfg.access.user_data_quota.insert(user.to_string(), 1);
    // Pin DC index "2" to the local stand-in listener above.
    cfg.dc_overrides
        .insert("2".to_string(), vec![tg_addr.to_string()]);
    let config = Arc::new(cfg);

    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));

    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));

    // In-memory duplex stands in for the client TCP stream; the server half is
    // split and wrapped in the crypto reader/writer used by the relay.
    let (server_side, client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let client_reader = make_crypto_reader(server_reader);
    let client_writer = make_crypto_writer(server_writer);

    // Pre-authenticated handshake result; keys/IVs are all-zero test vectors.
    let success = HandshakeSuccess {
        user: user.to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: peer_addr,
        is_tls: false,
    };

    let relay_task = tokio::spawn(RunningClientHandler::handle_authenticated_static(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats.clone(),
        config,
        buffer_pool,
        rng,
        None,
        route_runtime.clone(),
        "127.0.0.1:443".parse().unwrap(),
        peer_addr,
        ip_tracker.clone(),
    ));

    // Precondition: observe activation (slot or IP reserved) or a bounded early
    // exit before triggering the cutover, so the race window is actually open.
    let observed_progress = tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_user_curr_connects(user) >= 1
                || ip_tracker.get_active_ip_count(user).await >= 1
                || relay_task.is_finished()
            {
                return true;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    })
    .await
    .unwrap_or(false);
    assert!(
        observed_progress,
        "overlap race test precondition must observe activation or bounded early termination"
    );

    // Small delay widens the overlap between quota accounting and the cutover.
    tokio::time::sleep(Duration::from_millis(5)).await;
    let _ = route_runtime.set_mode(RelayRouteMode::Middle);

    let relay_result = tokio::time::timeout(Duration::from_secs(3), relay_task)
        .await
        .expect("overlap race relay must terminate")
        .expect("overlap race relay task must not panic");

    // Either enforcement path may win the race; both must be fail-closed.
    assert!(
        matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ..
})) + || matches!(relay_result, Err(ProxyError::Proxy(ref msg)) if msg == crate::proxy::route_mode::ROUTE_SWITCH_ERROR_MSG), + "overlap race must fail closed via quota enforcement or generic cutover termination" + ); + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "overlap race exit must release user current-connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "overlap race exit must release reserved user IP footprint" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn stress_drop_without_release_converges_to_zero_user_and_ip_state() { + let user = "gap-t05-drop-stress-user"; + let mut config = crate::config::ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4096); + + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + + let mut reservations = Vec::new(); + for idx in 0..512u16 { + let peer = std::net::SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(198, 51, (idx >> 8) as u8, (idx & 0xff) as u8)), + 30_000 + idx, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in stress precondition"); + reservations.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 512); + + for reservation in reservations { + std::thread::spawn(move || drop(reservation)) + .join() + .expect("drop thread must not panic"); + } + + tokio::time::timeout(std::time::Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + } + }) + .await + .expect("drop-only path must eventually release 
all user/IP reservations"); +} + +#[tokio::test] +async fn proxy_protocol_header_is_rejected_when_trust_list_is_empty() { + let mut cfg = crate::config::ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.server.proxy_protocol_trusted_cidrs.clear(); + + let config = std::sync::Arc::new(cfg); + let stats = std::sync::Arc::new(crate::stats::Stats::new()); + let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new( + vec![crate::config::UpstreamConfig { + upstream_type: crate::config::UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(128, std::time::Duration::from_secs(60))); + let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new()); + let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new()); + let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct)); + let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new()); + let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(2048); + let peer: std::net::SocketAddr = "198.51.100.80:55000".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + true, + )); + + let proxy_header = ProxyProtocolV1Builder::new() + .tcp4( + "203.0.113.9:32000".parse().unwrap(), + "192.0.2.8:443".parse().unwrap(), + ) + .build(); + client_side.write_all(&proxy_header).await.unwrap(); + drop(client_side); + + let result = tokio::time::timeout(std::time::Duration::from_secs(3), handler) + 
.await
        .unwrap()
        .unwrap();
    assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}

// Sends PROXY-protocol v1 headers from 32 distinct peers outside the trusted
// CIDR (10.0.0.0/8): every one must be rejected as InvalidProxyProtocol.
#[tokio::test]
async fn proxy_protocol_header_from_untrusted_peer_range_is_rejected_under_load() {
    let mut cfg = crate::config::ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.server.proxy_protocol_trusted_cidrs = vec!["10.0.0.0/8".parse().unwrap()];

    let config = std::sync::Arc::new(cfg);

    for idx in 0..32u16 {
        // Fresh per-iteration plumbing so each rejected probe is independent.
        let stats = std::sync::Arc::new(crate::stats::Stats::new());
        let upstream_manager = std::sync::Arc::new(crate::transport::UpstreamManager::new(
            vec![crate::config::UpstreamConfig {
                upstream_type: crate::config::UpstreamType::Direct {
                    interface: None,
                    bind_addresses: None,
                },
                weight: 1,
                enabled: true,
                scopes: String::new(),
                selected_scope: String::new(),
            }],
            1,
            1,
            1,
            1,
            false,
            stats.clone(),
        ));
        let replay_checker = std::sync::Arc::new(crate::stats::ReplayChecker::new(64, std::time::Duration::from_secs(60)));
        let buffer_pool = std::sync::Arc::new(crate::stream::BufferPool::new());
        let rng = std::sync::Arc::new(crate::crypto::SecureRandom::new());
        let route_runtime = std::sync::Arc::new(crate::proxy::route_mode::RouteRuntimeController::new(crate::proxy::route_mode::RelayRouteMode::Direct));
        let ip_tracker = std::sync::Arc::new(crate::ip_tracker::UserIpTracker::new());
        let beobachten = std::sync::Arc::new(crate::stats::beobachten::BeobachtenStore::new());

        let (server_side, mut client_side) = duplex(1024);
        // Peer in 203.0.113.0/24 (TEST-NET-3) — outside the trusted range.
        let peer = std::net::SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8)),
            55_000 + idx,
        );

        // NOTE(review): trailing `true` presumably enables PROXY-protocol
        // parsing for this listener — confirm against handle_client_stream.
        let handler = tokio::spawn(handle_client_stream(
            server_side,
            peer,
            config.clone(),
            stats,
            upstream_manager,
            replay_checker,
            buffer_pool,
            rng,
            None,
            route_runtime,
            None,
            ip_tracker,
            beobachten,
            true,
        ));

        let proxy_header = ProxyProtocolV1Builder::new()
            .tcp4(
                "203.0.113.10:32000".parse().unwrap(),
                "192.0.2.8:443".parse().unwrap(),
            )
            .build();
        client_side.write_all(&proxy_header).await.unwrap();
        drop(client_side);

        let result = tokio::time::timeout(std::time::Duration::from_secs(2), handler)
            .await
            .unwrap()
            .unwrap();
        assert!(
            matches!(result, Err(ProxyError::InvalidProxyProtocol)),
            "burst idx {idx}: untrusted source must be rejected"
        );
    }
}

// Verifies that a failed (over-limit) reservation neither leaks the per-user
// connection counter nor mutates IP-tracker state, and that an explicit
// release() followed by a cleanup-queue drain returns both to zero.
#[tokio::test]
async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
    let user = "leak-check-user";
    let mut config = crate::config::ProxyConfig::default();
    // Hard cap of exactly one concurrent TCP connection for this user.
    config.access.user_max_tcp_conns.insert(user.to_string(), 1);

    let stats = Arc::new(crate::stats::Stats::new());
    let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
    ip_tracker.set_user_limit(user, 8).await;

    let first_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 1)), 50001);
    let first = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        first_peer,
        ip_tracker.clone(),
    )
    .await
    .expect("first reservation must succeed");

    assert_eq!(stats.get_user_curr_connects(user), 1);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);

    // Second reservation from a different peer must hit the limit of 1.
    let second_peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 2)), 50002);
    let second = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        second_peer,
        ip_tracker.clone(),
    )
    .await;

    assert!(
        matches!(second, Err(crate::error::ProxyError::ConnectionLimitExceeded { user: denied }) if denied == user),
        "second reservation must be rejected at the configured tcp-conns limit"
    );
    assert_eq!(stats.get_user_curr_connects(user), 1, "failed acquisition must not leak a counter increment");
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1, "failed acquisition must not mutate IP tracker state");

    first.release().await;
    ip_tracker.drain_cleanup_queue().await;

assert_eq!(stats.get_user_curr_connects(user), 0);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}

// A 5-byte TLS-looking probe (too short to be a real ClientHello) must be
// spliced to the mask (fallback) backend and the backend's reply relayed
// verbatim — the censor-facing behavior of TLS fronting.
#[tokio::test]
async fn short_tls_probe_is_masked_through_client_pipeline() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10];
    let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();

    // Mask backend: expects the probe bytes unmodified, then answers HTTP.
    let accept_task = tokio::spawn({
        let probe = probe.clone();
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut got = vec![0u8; probe.len()];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    client_side.write_all(&probe).await.unwrap();
    let mut observed = vec![0u8; backend_reply.len()];
    client_side.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    accept_task.await.unwrap();
}

// Same masking path as above, but with a TLS 1.2 record version (0x0303)
// in the probe header — the fallback must be version-agnostic.
#[tokio::test]
async fn tls12_record_probe_is_masked_through_client_pipeline() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let probe = vec![0x16, 0x03, 0x03, 0x00, 0x10];
    let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();

    let accept_task = tokio::spawn({
        let probe = probe.clone();
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut got = vec![0u8; probe.len()];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "203.0.113.78:55001".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    client_side.write_all(&probe).await.unwrap();
    let mut observed = vec![0u8; backend_reply.len()];
    client_side.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    accept_task.await.unwrap();
}

// Accounting invariant: a single connection through handle_client_stream
// (here, a masked probe) must bump connects_all by exactly one.
#[tokio::test]
async fn handle_client_stream_increments_connects_all_exactly_once() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10];

    let accept_task = tokio::spawn({
        let probe = probe.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut got = vec![0u8; probe.len()];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    // Baseline before the connection; asserted as `before + 1` at the end.
    let before = stats.get_connects_all();
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats.clone(),
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    client_side.write_all(&probe).await.unwrap();
    drop(client_side);

    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(
        stats.get_connects_all(),
        before + 1,
        "handle_client_stream must increment connects_all exactly once"
    );
}

// Same single-increment invariant, exercised through the full
// ClientHandler::run entry point over a real loopback TCP pair.
#[tokio::test]
async fn running_client_handler_increments_connects_all_exactly_once() {
    let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = mask_listener.local_addr().unwrap();

    let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let front_addr = front_listener.local_addr().unwrap();

    let probe = [0x16, 0x03, 0x01, 0x00, 0x10];

    let mask_accept_task = tokio::spawn(async move {
        let (mut stream, _) = mask_listener.accept().await.unwrap();
        let mut got = [0u8; 5];
        stream.read_exact(&mut got).await.unwrap();
        assert_eq!(got, probe);
    });

let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    // Baseline before the connection; asserted as `before + 1` at the end.
    let before = stats.get_connects_all();
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    // Full server-side entry point: accept one front connection and drive it
    // through ClientHandler::run (not the handle_client_stream shortcut).
    let server_task = {
        let config = config.clone();
        let stats = stats.clone();
        let upstream_manager = upstream_manager.clone();
        let replay_checker = replay_checker.clone();
        let buffer_pool = buffer_pool.clone();
        let rng = rng.clone();
        let route_runtime = route_runtime.clone();
        let ip_tracker = ip_tracker.clone();
        let beobachten = beobachten.clone();

        tokio::spawn(async move {
            let (stream, peer) = front_listener.accept().await.unwrap();
            let real_peer_report = Arc::new(std::sync::Mutex::new(None));
            ClientHandler::new(
                stream,
                peer,
                config,
                stats,
                upstream_manager,
                replay_checker,
                buffer_pool,
                rng,
                None,
                route_runtime,
                None,
                ip_tracker,
                beobachten,
                false,
                real_peer_report,
            )
            .run()
            .await
        })
    };

    let mut client = TcpStream::connect(front_addr).await.unwrap();
    client.write_all(&probe).await.unwrap();
    drop(client);

    let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
        .await
        .unwrap()
        .unwrap();
    tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(
        stats.get_connects_all(),
        before + 1,
        "ClientHandler::run must increment connects_all exactly once"
    );
}

// A client that sends only a partial TLS record header and then stalls must
// be cut off by the client-handshake timeout (1 s here), not held forever.
#[tokio::test]
async fn partial_tls_header_stall_triggers_handshake_timeout() {
    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.timeouts.client_handshake = 1;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    // 5-byte TLS record header announcing a 0x0200-byte body that never comes.
    client_side
        .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
        .await
        .unwrap();

    let result = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout)));
}

/// Builds a faked-TLS ClientHello record of `tls_len` payload bytes whose
/// session-id digest field carries the HMAC-SHA256 over the record (digest
/// zeroed), with the little-endian `timestamp` XOR-ed into its last 4 bytes.
///
/// FIX: return type was the bare generic `Vec` (missing type parameter, which
/// does not compile in return position); restored `Vec<u8>` per the byte
/// buffer the body constructs. Same fix applied to the three helpers below.
fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> {
    assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");

    // 5-byte TLS record header + payload, filled with a recognizable pattern.
    let total_len = 5 + tls_len;
    let mut handshake = vec![0x42u8; total_len];

    handshake[0] = 0x16; // ContentType: handshake
    handshake[1] = 0x03; // record version 3.1 (TLS 1.0 compatibility)
    handshake[2] = 0x01;
    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());

    // Session-id length byte sits immediately after the digest field.
    let session_id_len: usize = 32;
    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;

    // HMAC is computed over the record with the digest field zeroed, then the
    // timestamp is XOR-ed into digest bytes 28..32 (matching the verifier).
    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
    let computed = sha256_hmac(secret, &handshake);
    let mut digest = computed;
    let ts = timestamp.to_le_bytes();
    for i in 0..4 {
        digest[28 + i] ^= ts[i];
    }

    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
    handshake
}

/// Convenience wrapper: valid faked-TLS ClientHello with a 600-byte payload.
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec<u8> {
    make_valid_tls_client_hello_with_len(secret, timestamp, 600)
}

/// Builds a structurally well-formed ClientHello (version, randoms, one
/// cipher suite, one compression method) carrying an ALPN extension (0x0010)
/// listing `alpn_protocols`, then signs it like the helper above.
fn make_valid_tls_client_hello_with_alpn(
    secret: &[u8],
    timestamp: u32,
    alpn_protocols: &[&[u8]],
) -> Vec<u8> {
    let mut body = Vec::new();
    body.extend_from_slice(&TLS_VERSION);
    body.extend_from_slice(&[0u8; 32]); // client random
    body.push(32); // session-id length
    body.extend_from_slice(&[0x42u8; 32]); // session id
    body.extend_from_slice(&2u16.to_be_bytes()); // cipher-suites length
    body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
    body.push(1); // compression-methods length
    body.push(0); // null compression

    let mut ext_blob = Vec::new();
    if !alpn_protocols.is_empty() {
        // ALPN protocol list: each entry is a length-prefixed byte string.
        let mut alpn_list = Vec::new();
        for proto in alpn_protocols {
            alpn_list.push(proto.len() as u8);
            alpn_list.extend_from_slice(proto);
        }

        let mut alpn_data = Vec::new();
        alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
        alpn_data.extend_from_slice(&alpn_list);

        ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); // ALPN extension id
        ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
        ext_blob.extend_from_slice(&alpn_data);
    }

    body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
    body.extend_from_slice(&ext_blob);

    // Handshake message: type 0x01 (ClientHello) + 24-bit body length.
    let mut handshake = Vec::new();
    handshake.push(0x01);
    let body_len = (body.len() as u32).to_be_bytes();
    handshake.extend_from_slice(&body_len[1..4]);
    handshake.extend_from_slice(&body);

    // TLS record wrapper.
    let mut record = Vec::new();
    record.push(0x16);
    record.extend_from_slice(&[0x03, 0x01]);
    record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
    record.extend_from_slice(&handshake);

    // Same digest scheme as make_valid_tls_client_hello_with_len.
    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
    let computed = sha256_hmac(secret, &record);
    let mut digest = computed;
    let ts = timestamp.to_le_bytes();
    for i in 0..4 {
        digest[28 + i] ^= ts[i];
    }

    record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
        .copy_from_slice(&digest);
    record
}

/// Wraps `payload` in a TLS application-data record (type 0x17, version 3.3).
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
    let mut record = Vec::with_capacity(5 + payload.len());
    record.push(0x17);
    record.extend_from_slice(&[0x03, 0x03]);
    record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
    record.extend_from_slice(payload);
    record
}

// An authenticated faked-TLS handshake must stay inside the proxy: the mask
// backend must never be contacted and connects_bad must not move.
#[tokio::test]
async fn valid_tls_path_does_not_fall_back_to_mask_backend() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let secret = [0x11u8; 16];
    let client_hello = make_valid_tls_client_hello(&secret, 0);

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;
    cfg.access.ignore_time_skew = true;
    // Hex secret matches the [0x11; 16] bytes signed into the ClientHello.
    cfg.access
        .users
        .insert("user".to_string(), "11111111111111111111111111111111".to_string());

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(131072);
    let peer: SocketAddr = "198.51.100.80:55002".parse().unwrap();
    let stats_for_assert = stats.clone();
    // connects_bad baseline: the authenticated path must leave it untouched.
    let bad_before = stats_for_assert.get_connects_bad();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    client_side.write_all(&client_hello).await.unwrap();

    // Server must answer with a TLS handshake record (0x16) of its own,
    // proving the faked-TLS path engaged instead of the mask fallback.
    let mut record_header = [0u8; 5];
    client_side.read_exact(&mut record_header).await.unwrap();
    assert_eq!(record_header[0], 0x16);

    drop(client_side);
    let handler_result = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    assert!(handler_result.is_err());

    // The mask backend listener must see no connection at all.
    let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await;
    assert!(
        no_mask_connect.is_err(),
        "Mask backend must not be contacted on authenticated TLS path"
    );

    let bad_after = stats_for_assert.get_connects_bad();
    assert_eq!(
        bad_before,
        bad_after,
        "Authenticated TLS path must not increment connects_bad"
    );
}

// Valid faked-TLS outer handshake but garbage inner MTProto: the session must
// fall back to the mask backend while remaining inside the TLS framing.
#[tokio::test]
async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let secret = [0x33u8; 16];
    let client_hello = make_valid_tls_client_hello(&secret, 0);
let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"still-tls-after-fallback".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); + + let expected_trailing_tls_record = trailing_tls_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "33333333333333333333333333333333".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(32768); + let peer: SocketAddr = "198.51.100.90:55111".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( 
+ server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&tls_app_record).await.unwrap(); + client_side.write_all(&trailing_tls_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x44u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0); + let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; + let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"second-tls-record".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); + + let expected_trailing_tls_record = trailing_tls_record.clone(); + let mask_accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = 
Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "44444444444444444444444444444444".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client.read_exact(&mut tls_response_head).await.unwrap(); 
+ assert_eq!(tls_response_head[0], 0x16); + + client.write_all(&tls_app_record).await.unwrap(); + client.write_all(&trailing_tls_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.censorship.alpn_enforce = true; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let 
replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x77u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + 
cfg.access + .users + .insert("user".to_string(), "77777777777777777777777777777777".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn burst_invalid_tls_probes_are_masked_verbatim() { + const N: usize = 12; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + for _ in 0..N { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = 
vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + } + }); + + let mut handlers = Vec::with_capacity(N); + for i in 0..N { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "88888888888888888888888888888888".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i) + .parse() + .unwrap(); + let probe_bytes = probe.clone(); + + let h = tokio::spawn(async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe_bytes).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap() + .unwrap(); + }); + 
handlers.push(h); + } + + for h in handlers { + tokio::time::timeout(Duration::from_secs(5), h) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[test] +fn unexpected_eof_is_classified_without_string_matching() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)); + let peer_ip: IpAddr = "198.51.100.200".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[expected_64_got_0]"), + "UnexpectedEof must be classified as expected_64_got_0" + ); + assert!( + snapshot.contains("198.51.100.200-1"), + "Classified record must include source IP" + ); +} + +#[test] +fn non_eof_error_is_classified_as_other() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let non_eof = ProxyError::Io(std::io::Error::other("different error")); + let peer_ip: IpAddr = "203.0.113.201".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &non_eof); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!( + snapshot.contains("[other]"), + "Non-EOF errors must map to other" + ); + assert!( + snapshot.contains("203.0.113.201-1"), + "Classified record must include source IP" + ); + assert!( + !snapshot.contains("[expected_64_got_0]"), + "Non-EOF errors must not be misclassified as expected_64_got_0" + ); +} + +#[test] +fn beobachten_ttl_zero_minutes_is_floored_to_one_minute() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 0; + + let ttl 
= beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(60), + "beobachten_minutes=0 must be fail-closed to a one-minute minimum TTL" + ); +} + +#[test] +fn beobachten_ttl_positive_minutes_remain_unchanged() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 7; + + let ttl = beobachten_ttl(&config); + assert_eq!( + ttl, + Duration::from_secs(7 * 60), + "configured positive beobachten TTL must be preserved" + ); +} + +#[tokio::test] +async fn tcp_limit_rejection_does_not_reserve_ip_or_trigger_rollback() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 1); + + let stats = Stats::new(); + stats.increment_user_curr_connects("user"); + + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.210:50000".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + )); + assert_eq!( + ip_tracker.get_active_ip_count("user").await, + 0, + "Rejected client must not reserve IP slot" + ); + assert_eq!( + stats.get_ip_reservation_rollback_tcp_limit_total(), + 0, + "No rollback should occur when reservation is not taken" + ); +} + +#[tokio::test] +async fn zero_tcp_limit_rejects_without_ip_or_counter_side_effects() { + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert("user".to_string(), 0); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.211:50001".parse().unwrap(); + + let result = RunningClientHandler::check_user_limits_static( + "user", + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user" + ));
+ assert_eq!(stats.get_user_curr_connects("user"), 0); + assert_eq!(ip_tracker.get_active_ip_count("user").await, 0); +} + +#[tokio::test] +async fn check_user_limits_static_success_does_not_leak_counter_or_ip_reservation() { + let user = "check-helper-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + let peer_addr: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + + let first = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(first.is_ok(), "first check-only limit validation must succeed"); + + let second = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(second.is_ok(), "second check-only validation must not fail from leaked state"); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn stress_check_user_limits_static_success_never_leaks_state() { + let user = "check-helper-stress-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 1); + + let stats = Stats::new(); + let ip_tracker = UserIpTracker::new(); + + for i in 0..4096u16 { + let peer_addr = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 110, (i % 250) as u8 + 1)), + 40000 + (i % 1024), + ); + + let result = RunningClientHandler::check_user_limits_static( + user, + &config, + &stats, + peer_addr, + &ip_tracker, + ) + .await; + assert!(result.is_ok(), "check-only helper must remain leak-free under stress"); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "stress success loop must not leak user connection counters" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "stress success loop must not 
leak active IP reservations" + ); +} + +#[tokio::test] +async fn concurrent_distinct_ip_rejections_rollback_user_counter_without_leak() { + let user = "rollback-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 128); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let keeper_peer: SocketAddr = "198.51.100.212:50002".parse().unwrap(); + let keeper = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + keeper_peer, + ip_tracker.clone(), + ) + .await + .expect("keeper reservation must succeed"); + + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..64u8 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 101, i.saturating_add(1))), + 41000 + i as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "rollback-storm-user" + )); + }); + } + + while let Some(joined) = tasks.join_next().await { + joined.unwrap(); + } + + assert_eq!( + stats.get_user_curr_connects(user), + 1, + "failed distinct-IP attempts must rollback acquired user slots" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 1, + "failed distinct-IP attempts must not leave extra active IPs" + ); + + keeper.release().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn explicit_reservation_release_cleans_user_and_ip_immediately() { + let user = "release-user"; + let peer_addr: SocketAddr = 
"198.51.100.240:50002".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "explicit release must synchronously free user connection slot" + ); + assert_eq!( + ip_tracker.get_active_ip_count(user).await, + 0, + "explicit release must synchronously remove reserved user IP" + ); +} + +#[tokio::test] +async fn explicit_reservation_release_does_not_double_decrement_on_drop() { + let user = "release-once-user"; + let peer_addr: SocketAddr = "198.51.100.241:50003".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker, + ) + .await + .expect("reservation acquisition must succeed"); + + reservation.release().await; + + assert_eq!( + stats.get_user_curr_connects(user), + 0, + "release must disarm drop and prevent double decrement" + ); +} + +#[tokio::test] +async fn drop_fallback_eventually_cleans_user_and_ip_reservation() { + let user = "drop-fallback-user"; + let peer_addr: SocketAddr = "198.51.100.242:50004".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = 
Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_addr, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed"); + + assert_eq!(stats.get_user_curr_connects(user), 1); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + drop(reservation); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("drop fallback must eventually clean both user slot and active IP"); +} + +#[tokio::test] +async fn explicit_release_allows_immediate_cross_ip_reacquire_under_limit() { + let user = "cross-ip-user"; + let peer1: SocketAddr = "198.51.100.243:50005".parse().unwrap(); + let peer2: SocketAddr = "198.51.100.244:50006".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 4); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer1, + ip_tracker.clone(), + ) + .await + .expect("first reservation must succeed"); + first.release().await; + + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer2, + ip_tracker.clone(), + ) + .await + .expect("second reservation must succeed immediately after explicit release"); + second.release().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn release_abort_storm_does_not_leak_user_or_ip_reservations() { + 
const ATTEMPTS: usize = 256; + + let user = "release-abort-storm-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), ATTEMPTS + 16); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for idx in 0..ATTEMPTS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 114, (idx % 250 + 1) as u8)), + 52000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("reservation acquisition must succeed in abort storm"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = release_task.await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("release abort storm must not leak user slots or active IP entries"); +} + +#[tokio::test] +async fn release_abort_loop_preserves_immediate_same_ip_reacquire() { + const ITERATIONS: usize = 128; + + let user = "release-abort-reacquire-user"; + let peer: SocketAddr = "198.51.100.246:53001".parse().unwrap(); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + for _ in 0..ITERATIONS { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("baseline acquisition must succeed"); + + let release_task = tokio::spawn(async move { + reservation.release().await; + }); + release_task.abort(); + let _ = 
release_task.await; + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::task::yield_now().await; + tokio::time::sleep(Duration::from_millis(2)).await; + } + }) + .await + .expect("aborted release must still converge to zero footprint"); + } + + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("same-ip reacquire must succeed after repeated abort-release churn"); + reservation.release().await; +} + +#[tokio::test] +async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() { + const RESERVATIONS: usize = 192; + + let user = "mixed-wave-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), RESERVATIONS + 8); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut reservations = Vec::with_capacity(RESERVATIONS); + for idx in 0..RESERVATIONS { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 115, (idx % 250 + 1) as u8)), + 54000 + idx as u16, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("mixed-wave acquisition must succeed"); + reservations.push(reservation); + } + + let mut seed: u64 = 0xDEAD_BEEF_CAFE_BA5E; + let mut join_set = tokio::task::JoinSet::new(); + for reservation in reservations { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + match seed % 3 { + 0 => { + join_set.spawn(async move { + reservation.release().await; + }); + } + 1 => { + drop(reservation); + } + _ => { + let task = tokio::spawn(async move { + reservation.release().await; + }); + task.abort(); + let _ = task.await; + } + } + } + + while let 
Some(result) = join_set.join_next().await {
        result.expect("release subtask must not panic");
    }

    // Poll (with yield) until both the per-user connection counter and the
    // active-IP set drain; the 2s timeout bounds the async Drop fallback path.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_user_curr_connects(user) == 0
                && ip_tracker.get_active_ip_count(user).await == 0
            {
                break;
            }
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
    })
    .await
    .expect("mixed release/drop/abort wave must converge to zero footprint");
}

// Two users interleave 64 reservations; each release task is immediately
// aborted. Verifies that aborting one user's release path never corrupts the
// other user's counters (per-user cleanup isolation).
#[tokio::test]
async fn parallel_users_abort_release_isolation_preserves_independent_cleanup() {
    let user_a = "abort-isolation-a";
    let user_b = "abort-isolation-b";

    let mut config = ProxyConfig::default();
    config.access.user_max_tcp_conns.insert(user_a.to_string(), 64);
    config.access.user_max_tcp_conns.insert(user_b.to_string(), 64);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut tasks = tokio::task::JoinSet::new();
    for idx in 0..64usize {
        // Alternate reservations between the two users.
        let user = if idx % 2 == 0 { user_a } else { user_b };
        let peer = SocketAddr::new(
            IpAddr::V4(std::net::Ipv4Addr::new(198, 18, 0, (idx % 250 + 1) as u8)),
            55000 + idx as u16,
        );
        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
            user,
            &config,
            stats.clone(),
            peer,
            ip_tracker.clone(),
        )
        .await
        .expect("parallel-user acquisition must succeed");

        tasks.spawn(async move {
            // Spawn the release and abort it right away: cleanup must still
            // happen via the reservation's Drop fallback if release was cut short.
            let t = tokio::spawn(async move {
                reservation.release().await;
            });
            t.abort();
            let _ = t.await;
        });
    }

    while let Some(result) = tasks.join_next().await {
        result.expect("parallel-user abort task must not panic");
    }

    // Both users must independently converge to a zero footprint.
    tokio::time::timeout(Duration::from_secs(2), async {
        loop {
            if stats.get_user_curr_connects(user_a) == 0
                && stats.get_user_curr_connects(user_b) == 0
                && ip_tracker.get_active_ip_count(user_a).await == 0
                && ip_tracker.get_active_ip_count(user_b).await == 0
            {
                break;
            }
            tokio::task::yield_now().await;
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
    })
    .await
    .expect("parallel users must cleanup independently under abort churn");
}

// 64 reservations on distinct IPs are released concurrently; afterwards the
// user connection counter and active-IP set must both read exactly zero
// (explicit release path, no Drop fallback involved).
#[tokio::test]
async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() {
    const RESERVATIONS: usize = 64;

    let user = "release-storm-user";
    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), RESERVATIONS + 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut reservations = Vec::with_capacity(RESERVATIONS);
    for idx in 0..RESERVATIONS {
        // One unique IP per reservation so every acquire grows the IP set.
        let ip = std::net::Ipv4Addr::new(203, 0, 113, (idx + 1) as u8);
        let peer = SocketAddr::new(IpAddr::V4(ip), 51000 + idx as u16);
        let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
            user,
            &config,
            stats.clone(),
            peer,
            ip_tracker.clone(),
        )
        .await
        .expect("reservation acquisition in storm must succeed");
        reservations.push(reservation);
    }

    assert_eq!(stats.get_user_curr_connects(user), RESERVATIONS as u64);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, RESERVATIONS);

    // Release everything concurrently.
    let mut tasks = tokio::task::JoinSet::new();
    for reservation in reservations {
        tasks.spawn(async move {
            reservation.release().await;
        });
    }

    while let Some(result) = tasks.join_next().await {
        result.expect("release task must not panic");
    }

    assert_eq!(
        stats.get_user_curr_connects(user),
        0,
        "release storm must drain user current-connection counter to zero"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "release storm must clear all active IP entries"
    );
}

// Points DC 2 at a freshly-closed local port so the upstream connect fails,
// then runs the authenticated-relay path and asserts the error return path
// released both the user slot and the IP reservation before returning.
#[tokio::test]
async fn relay_connect_error_releases_user_and_ip_before_return() {
    let user = "relay-error-user";
    let peer_addr: SocketAddr = "198.51.100.245:50007".parse().unwrap();

    // Bind-then-drop yields a port that is very likely unbound (a race with
    // another process re-binding it is theoretically possible but negligible).
    let dead_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let dead_port = dead_listener.local_addr().unwrap().port();
    drop(dead_listener);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 1);
    config
        .dc_overrides
        .insert("2".to_string(), vec![format!("127.0.0.1:{dead_port}")]);
    let config = Arc::new(config);

    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));

    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));

    // In-memory duplex stands in for the client TCP stream; only the server
    // half is driven, the client half is kept alive but unused.
    let (server_side, _client_side) = duplex(64 * 1024);
    let (server_reader, server_writer) = tokio::io::split(server_side);
    let client_reader = make_crypto_reader(server_reader);
    let client_writer = make_crypto_writer(server_writer);

    // Minimal post-handshake state targeting the overridden (dead) DC 2.
    let success = HandshakeSuccess {
        user: user.to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Intermediate,
        dec_key: [0u8; 32],
        dec_iv: 0,
        enc_key: [0u8; 32],
        enc_iv: 0,
        peer: peer_addr,
        is_tls: false,
    };

    let result = RunningClientHandler::handle_authenticated_static(
        client_reader,
        client_writer,
        success,
        upstream_manager,
        stats.clone(),
        config,
        buffer_pool,
        rng,
        None,
        route_runtime,
        "127.0.0.1:443".parse().unwrap(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await;

    assert!(result.is_err(), "relay must fail when upstream DC is unreachable");
    assert_eq!(
        stats.get_user_curr_connects(user),
        0,
        "error return must release user slot before returning"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        0,
        "error return must release user IP reservation before returning"
    );
}
// Two reservations from the SAME peer IP: one explicitly released, one
// dropped. The connection counter must step 2 -> 1 -> 0 while the IP stays
// "active" until the last reservation for it is gone.
#[tokio::test]
async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
    let user = "same-ip-mixed-user";
    let peer_addr: SocketAddr = "198.51.100.246:50008".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("first reservation must succeed");
    let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("second reservation must succeed");

    // Same IP twice: two connections, but only one active IP entry.
    assert_eq!(stats.get_user_curr_connects(user), 2);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);

    reservation_a.release().await;
    assert_eq!(
        stats.get_user_curr_connects(user),
        1,
        "explicit release must decrement only one active reservation"
    );
    assert_eq!(
        ip_tracker.get_active_ip_count(user).await,
        1,
        "same IP must remain active while second reservation exists"
    );

    // Drop (no explicit release): cleanup happens asynchronously, so poll.
    drop(reservation_b);
    tokio::time::timeout(Duration::from_secs(1), async {
        loop {
            if stats.get_user_curr_connects(user) == 0
                && ip_tracker.get_active_ip_count(user).await == 0
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("drop fallback must clear final same-IP reservation");
}

// Dropping one of two same-IP reservations must leave the IP active (count
// stays 1/1); releasing the survivor then drains everything to zero.
#[tokio::test]
async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
    let user = "same-ip-drop-one-user";
    let peer_addr: SocketAddr = "198.51.100.247:50009".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("first reservation must succeed");
    let reservation_b = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("second reservation must succeed");

    drop(reservation_a);
    // Wait for exactly 1 connection and 1 active IP (the survivor).
    tokio::time::timeout(Duration::from_secs(1), async {
        loop {
            if stats.get_user_curr_connects(user) == 1
                && ip_tracker.get_active_ip_count(user).await == 1
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("dropping one reservation must keep same-IP activity for remaining reservation");

    reservation_b.release().await;
    tokio::time::timeout(Duration::from_secs(1), async {
        loop {
            if stats.get_user_curr_connects(user) == 0
                && ip_tracker.get_active_ip_count(user).await == 0
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("final release must converge to zero footprint after async fallback cleanup");
}

// A user already at their data quota must be rejected BEFORE any IP slot is
// reserved — and therefore without ever hitting the rollback counter.
#[tokio::test]
async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
    let mut config = ProxyConfig::default();
    config.access.user_data_quota.insert("user".to_string(), 1024);

    let stats = Stats::new();
    // Pre-consume the entire quota so the limit check fires immediately.
    stats.add_user_octets_from("user", 1024);

    let ip_tracker = UserIpTracker::new();
    let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();

    let result = RunningClientHandler::check_user_limits_static(
        "user",
        &config,
        &stats,
        peer_addr,
        &ip_tracker,
    )
    .await;

    assert!(matches!(
        result,
        Err(ProxyError::DataQuotaExceeded { user }) if user == "user"
    ));
    assert_eq!(
        ip_tracker.get_active_ip_count("user").await,
        0,
        "Quota-rejected client must not reserve IP slot"
    );
    assert_eq!(
        stats.get_ip_reservation_rollback_quota_limit_total(),
        0,
        "No rollback should occur when reservation is not taken"
    );
}

// An expired user must be rejected without touching the connection counter
// or the IP tracker.
#[tokio::test]
async fn expired_user_rejection_does_not_reserve_ip_or_increment_curr_connects() {
    let mut config = ProxyConfig::default();
    // Expiration one second in the past.
    config
        .access
        .user_expirations
        .insert("user".to_string(), chrono::Utc::now() - chrono::Duration::seconds(1));

    let stats = Stats::new();
    let ip_tracker = UserIpTracker::new();
    let peer_addr: SocketAddr = "203.0.113.212:50002".parse().unwrap();

    let result = RunningClientHandler::check_user_limits_static(
        "user",
        &config,
        &stats,
        peer_addr,
        &ip_tracker,
    )
    .await;

    assert!(matches!(
        result,
        Err(ProxyError::UserExpired { user }) if user == "user"
    ));
    assert_eq!(stats.get_user_curr_connects("user"), 0);
    assert_eq!(ip_tracker.get_active_ip_count("user").await, 0);
}

// With a unique-IP limit of 1, a SECOND reservation from the SAME IP must be
// allowed (it is not a new IP), yielding two connections on one active IP.
#[tokio::test]
async fn same_ip_second_reservation_succeeds_under_unique_ip_limit_one() {
    let user = "same-ip-unique-limit-user";
    let peer_addr: SocketAddr = "198.51.100.248:50010".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());
    ip_tracker.set_user_limit(user, 1).await;

    let first = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("first reservation must succeed");
    let second = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("second reservation from same IP must succeed under unique-ip limit=1");

    assert_eq!(stats.get_user_curr_connects(user), 2);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);

    first.release().await;
    second.release().await;
    assert_eq!(stats.get_user_curr_connects(user), 0);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}

// With a unique-IP limit of 1, a connection from a second, DISTINCT IP must
// be rejected, leaving exactly the first IP active.
#[tokio::test]
async fn second_distinct_ip_is_rejected_under_unique_ip_limit_one() {
    let user = "distinct-ip-unique-limit-user";
    let peer1: SocketAddr = "198.51.100.249:50011".parse().unwrap();
    let peer2: SocketAddr = "198.51.100.250:50012".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());
    ip_tracker.set_user_limit(user, 1).await;

    let first = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer1,
        ip_tracker.clone(),
    )
    .await
    .expect("first reservation must succeed");

    let second = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer2,
        ip_tracker.clone(),
    )
    .await;

    assert!(matches!(
        second,
        Err(ProxyError::ConnectionLimitExceeded { user }) if user == "distinct-ip-unique-limit-user"
    ));
    // First reservation is untouched by the rejection.
    assert_eq!(stats.get_user_curr_connects(user), 1);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);

    first.release().await;
}

// Drops the reservation on a plain OS thread (no tokio context). Cleanup must
// still run — i.e. the Drop path must use a runtime handle captured at
// acquisition time rather than the current thread's.
#[tokio::test]
async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
    let user = "cross-thread-drop-user";
    let peer_addr: SocketAddr = "198.51.100.251:50013".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 8);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("reservation acquisition must succeed");

    assert_eq!(stats.get_user_curr_connects(user), 1);
    assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);

    // Drop outside of any tokio worker thread.
    std::thread::spawn(move || {
        drop(reservation);
    })
    .join()
    .expect("drop thread must not panic");

    tokio::time::timeout(Duration::from_secs(1), async {
        loop {
            if stats.get_user_curr_connects(user) == 0
                && ip_tracker.get_active_ip_count(user).await == 0
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("cross-thread drop must still converge to zero user and IP footprint")
;
}

// With limit=1, a cross-thread drop followed by a fresh acquire must succeed
// once the asynchronous cleanup has settled (no leaked slot).
#[tokio::test]
async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
    let user = "cross-thread-reacquire-user";
    let peer_addr: SocketAddr = "198.51.100.252:50014".parse().unwrap();

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert(user.to_string(), 1);

    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats.clone(),
        peer_addr,
        ip_tracker.clone(),
    )
    .await
    .expect("initial reservation must succeed");

    std::thread::spawn(move || {
        drop(reservation);
    })
    .join()
    .expect("drop thread must not panic");

    // Let the async Drop cleanup settle before attempting the reacquire.
    tokio::time::timeout(Duration::from_secs(1), async {
        loop {
            if stats.get_user_curr_connects(user) == 0
                && ip_tracker.get_active_ip_count(user).await == 0
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    })
    .await
    .expect("cross-thread cleanup must settle before reacquire check");

    let reacquire = RunningClientHandler::acquire_user_connection_reservation_static(
        user,
        &config,
        stats,
        peer_addr,
        ip_tracker,
    )
    .await;
    assert!(
        reacquire.is_ok(),
        "reacquire must succeed after cross-thread drop cleanup"
    );
}

// 64 IPs x 8 attempts all racing against a user already at limit=1 (counter
// pre-incremented). Every attempt must be rejected, and the rejections must
// leave no active/recent IP state and no rollback events behind.
#[tokio::test]
async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
    const PARALLEL_IPS: usize = 64;
    const ATTEMPTS_PER_IP: usize = 8;

    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert("user".to_string(), 1);

    let config = Arc::new(config);
    let stats = Arc::new(Stats::new());
    // Occupy the single slot up-front so every check below must fail.
    stats.increment_user_curr_connects("user");
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut tasks = tokio::task::JoinSet::new();
    for i in 0..PARALLEL_IPS {
        let config = config.clone();
        let stats = stats.clone();
        let ip_tracker = ip_tracker.clone();

        tasks.spawn(async move {
            let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 100, (i + 1) as u8));
            for _ in 0..ATTEMPTS_PER_IP {
                let peer_addr = SocketAddr::new(ip, 40000 + i as u16);
                let result = RunningClientHandler::check_user_limits_static(
                    "user",
                    &config,
                    &stats,
                    peer_addr,
                    &ip_tracker,
                )
                .await;

                assert!(matches!(
                    result,
                    Err(ProxyError::ConnectionLimitExceeded { user }) if user == "user"
                ));
            }
        });
    }

    while let Some(joined) = tasks.join_next().await {
        joined.unwrap();
    }

    assert_eq!(
        ip_tracker.get_active_ip_count("user").await,
        0,
        "Concurrent rejected attempts must not leave active IP reservations"
    );

    let recent = ip_tracker
        .get_recent_ips_for_users(&["user".to_string()])
        .await;
    assert!(
        recent
            .get("user")
            .map(|ips| ips.is_empty())
            .unwrap_or(true),
        "Concurrent rejected attempts must not leave recent IP footprint"
    );

    assert_eq!(
        stats.get_ip_reservation_rollback_tcp_limit_total(),
        0,
        "No rollback should occur under concurrent rejection storms"
    );
}

// 64 simultaneous acquires against a fresh limit=1 user: the limit gate must
// be atomic, admitting exactly one winner.
#[tokio::test]
async fn atomic_limit_gate_allows_only_one_concurrent_acquire() {
    let mut config = ProxyConfig::default();
    config
        .access
        .user_max_tcp_conns
        .insert("user".to_string(), 1);

    let config = Arc::new(config);
    let stats = Arc::new(Stats::new());
    let ip_tracker = Arc::new(UserIpTracker::new());

    let mut tasks = tokio::task::JoinSet::new();
    for i in 0..64u16 {
        let config = config.clone();
        let stats = stats.clone();
        let
ip_tracker = ip_tracker.clone();
        tasks.spawn(async move {
            let peer = SocketAddr::new(
                IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)),
                30000 + i,
            );
            RunningClientHandler::acquire_user_connection_reservation_static(
                "user",
                &config,
                stats,
                peer,
                ip_tracker,
            )
            .await
            .ok()
        });
    }

    let mut successes = 0u64;
    // Winners are held (not dropped) until counting finishes, so the slot
    // cannot be recycled mid-race and admit a second winner.
    let mut held_reservations = Vec::new();
    while let Some(joined) = tasks.join_next().await {
        if let Some(reservation) = joined.unwrap() {
            successes += 1;
            held_reservations.push(reservation);
        }
    }

    assert_eq!(
        successes, 1,
        "exactly one concurrent acquire must pass for a limit=1 user"
    );
    assert_eq!(stats.get_user_curr_connects("user"), 1);

    drop(held_reservations);
}

// A PROXY-protocol v1 header arriving from a peer OUTSIDE the configured
// trusted CIDR (198.51.100.44 vs 10.10.0.0/16) must be rejected with
// InvalidProxyProtocol.
#[tokio::test]
async fn untrusted_proxy_header_source_is_rejected() {
    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()];

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(2048);
    let peer: SocketAddr = "198.51.100.44:55000".parse().unwrap();

    // Final `true` flag enables PROXY-protocol parsing on this listener.
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        true,
    ));

    let proxy_header = ProxyProtocolV1Builder::new()
        .tcp4(
            "203.0.113.9:32000".parse().unwrap(),
            "192.0.2.8:443".parse().unwrap(),
        )
        .build();
    client_side.write_all(&proxy_header).await.unwrap();
    drop(client_side);

    let result = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}

// With NO trusted CIDRs configured at all, a PROXY header must be rejected by
// default (deny-by-default, never trust-by-default).
#[tokio::test]
async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() {
    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.server.proxy_protocol_trusted_cidrs.clear();

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(2048);
    let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats,
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        true,
    ));

    let proxy_header = ProxyProtocolV1Builder::new()
        .tcp4(
            "203.0.113.9:32000".parse().unwrap(),
            "192.0.2.8:443".parse().unwrap(),
        )
        .build();
    client_side.write_all(&proxy_header).await.unwrap();
    drop(client_side);

    let result = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}

// A TLS record header advertising MAX_TLS_PLAINTEXT_SIZE+1 must not be
// handled as MTProxy traffic: the connection must be spliced to the mask
// (fallback HTTPS) backend and the probe bytes forwarded verbatim, so a
// censor's scanner sees a legitimate server response.
#[tokio::test]
async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    // 5-byte TLS record header: handshake(0x16), TLS 1.0, oversized length.
    let probe = [
        0x16,
        0x03,
        0x01,
        (((MAX_TLS_PLAINTEXT_SIZE + 1) >> 8) & 0xff) as u8,
        ((MAX_TLS_PLAINTEXT_SIZE + 1) & 0xff) as u8,
    ];
    let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec();

    // Fake mask backend: asserts it receives the probe unmodified, replies.
    let accept_task = tokio::spawn({
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut got = [0u8; 5];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let bad_before = stats.get_connects_bad();
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats.clone(),
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    // Send the oversized-record probe; expect the mask backend's reply back,
    // byte-for-byte, through the spliced connection.
    client_side.write_all(&probe).await.unwrap();
    let mut observed = vec![0u8; backend_reply.len()];
    client_side.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    accept_task.await.unwrap();

    assert_eq!(
        stats.get_connects_bad(),
        bad_before + 1,
        "Oversized TLS probe must be classified as bad"
    );
}

// Same oversized-record scenario as above, but driven through the full
// ClientHandler over real TCP sockets instead of the generic stream entry.
#[tokio::test]
async fn oversized_tls_record_is_masked_in_client_handler_pipeline() {
    let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = mask_listener.local_addr().unwrap();

    let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let front_addr = front_listener.local_addr().unwrap();

    let probe = [
        0x16,
        0x03,
        0x01,
        (((MAX_TLS_PLAINTEXT_SIZE + 1) >> 8) & 0xff) as u8,
        ((MAX_TLS_PLAINTEXT_SIZE + 1) & 0xff) as u8,
    ];
    let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec();

    let mask_accept_task = tokio::spawn({
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = mask_listener.accept().await.unwrap();
            let mut got = [0u8; 5];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let server_task = {
        let config = config.clone();
        let stats = stats.clone();
        let upstream_manager = upstream_manager.clone();
        let replay_checker = replay_checker.clone();
        let buffer_pool = buffer_pool.clone();
        let rng = rng.clone();
        let route_runtime = route_runtime.clone();
        let ip_tracker = ip_tracker.clone();
        let beobachten = beobachten.clone();

        tokio::spawn(async move {
            let (stream, peer) = front_listener.accept().await.unwrap();
            let real_peer_report = Arc::new(std::sync::Mutex::new(None));
            ClientHandler::new(
                stream,
                peer,
                config,
                stats,
                upstream_manager,
                replay_checker,
                buffer_pool,
                rng,
                None,
                route_runtime,
                None,
                ip_tracker,
                beobachten,
                false,
                real_peer_report,
            )
            .run()
            .await
        })
    };

    let mut client = TcpStream::connect(front_addr).await.unwrap();
    client.write_all(&probe).await.unwrap();

    let mut observed = vec![0u8; backend_reply.len()];
    client.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
        .await
        .unwrap()
        .unwrap();

    drop(client);

    let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
        .await
        .unwrap()
        .unwrap();
}

// A record length of MIN_TLS_CLIENT_HELLO_SIZE - 1 (structurally too short to
// be a ClientHello) must likewise be spliced to the mask backend.
#[tokio::test]
async fn tls_record_len_min_minus_1_is_rejected_in_generic_stream_pipeline() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let probe = [
        0x16,
        0x03,
        0x01,
        (((MIN_TLS_CLIENT_HELLO_SIZE - 1) >> 8) & 0xff) as u8,
        ((MIN_TLS_CLIENT_HELLO_SIZE - 1) & 0xff) as u8,
    ];
    let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec();

    let accept_task = tokio::spawn({
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut got = [0u8; 5];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let bad_before = stats.get_connects_bad();
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten =
Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(4096);
    let peer: SocketAddr = "203.0.113.130:55130".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats.clone(),
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
        None,
        ip_tracker,
        beobachten,
        false,
    ));

    // The undersized probe must be forwarded to the mask backend verbatim.
    client_side.write_all(&probe).await.unwrap();
    let mut observed = vec![0u8; backend_reply.len()];
    client_side.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
    accept_task.await.unwrap();

    assert_eq!(
        stats.get_connects_bad(),
        bad_before + 1,
        "TLS record length below minimum structural ClientHello size must be rejected"
    );
}

// Same undersized-record scenario, driven through the full ClientHandler over
// real TCP sockets.
#[tokio::test]
async fn tls_record_len_min_minus_1_is_rejected_in_client_handler_pipeline() {
    let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = mask_listener.local_addr().unwrap();

    let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let front_addr = front_listener.local_addr().unwrap();

    let probe = [
        0x16,
        0x03,
        0x01,
        (((MIN_TLS_CLIENT_HELLO_SIZE - 1) >> 8) & 0xff) as u8,
        ((MIN_TLS_CLIENT_HELLO_SIZE - 1) & 0xff) as u8,
    ];
    let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec();

    let mask_accept_task = tokio::spawn({
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = mask_listener.accept().await.unwrap();
            let mut got = [0u8; 5];
            stream.read_exact(&mut got).await.unwrap();
            assert_eq!(got, probe);
            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let server_task = {
        let config = config.clone();
        let stats = stats.clone();
        let upstream_manager = upstream_manager.clone();
        let replay_checker = replay_checker.clone();
        let buffer_pool = buffer_pool.clone();
        let rng = rng.clone();
        let route_runtime = route_runtime.clone();
        let ip_tracker = ip_tracker.clone();
        let beobachten = beobachten.clone();

        tokio::spawn(async move {
            let (stream, peer) = front_listener.accept().await.unwrap();
            let real_peer_report = Arc::new(std::sync::Mutex::new(None));
            ClientHandler::new(
                stream,
                peer,
                config,
                stats,
                upstream_manager,
                replay_checker,
                buffer_pool,
                rng,
                None,
                route_runtime,
                None,
                ip_tracker,
                beobachten,
                false,
                real_peer_report,
            )
            .run()
            .await
        })
    };

    let mut client = TcpStream::connect(front_addr).await.unwrap();
    client.write_all(&probe).await.unwrap();

    let mut observed = vec![0u8; backend_reply.len()];
    client.read_exact(&mut observed).await.unwrap();
    assert_eq!(observed, backend_reply);

    tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
        .await
        .unwrap()
        .unwrap();

    drop(client);

    let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
        .await
        .unwrap()
        .unwrap();
}

// Boundary acceptance: a VALID ClientHello padded to exactly
// MAX_TLS_PLAINTEXT_SIZE (16384) must be accepted as MTProxy traffic —
// no mask fallback, no bad-connection count.
#[tokio::test]
async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let secret = [0x55u8; 16];
    let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_PLAINTEXT_SIZE);

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_proxy_protocol = 0;
    cfg.access.ignore_time_skew = true;
    // Hex of the secret used to build the ClientHello above.
    cfg.access
        .users
        .insert("user".to_string(), "55555555555555555555555555555555".to_string());

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());
    let bad_before = stats.get_connects_bad();
    let upstream_manager = Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats.clone(),
    ));
    let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
    let buffer_pool = Arc::new(BufferPool::new());
    let rng = Arc::new(SecureRandom::new());
    let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
    let ip_tracker = Arc::new(UserIpTracker::new());
    let beobachten = Arc::new(BeobachtenStore::new());

    let (server_side, mut client_side) = duplex(131072);
    let peer: SocketAddr = "198.51.100.55:56055".parse().unwrap();

    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        config,
        stats.clone(),
        upstream_manager,
        replay_checker,
        buffer_pool,
        rng,
        None,
        route_runtime,
+ None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_PLAINTEXT_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, 
+ weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + 
"Valid max-length ClientHello must not increment bad counter" + ); +} + +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn wait_for_user_and_ip_zero( + stats: &Arc, + ip_tracker: &Arc, + user: &str, +) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_user_curr_connects(user) == 0 + && ip_tracker.get_active_ip_count(user).await == 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("cleanup must converge to zero user and IP footprint"); +} + +async fn burst_acquire_distinct_ips( + user: &'static str, + config: Arc, + stats: Arc, + ip_tracker: Arc, + third_octet: u8, + attempts: u16, +) -> (Vec, usize) { + let mut tasks = tokio::task::JoinSet::new(); + for i in 0..attempts { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + let host = (i as u8).saturating_add(1); + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third_octet, host)), + 55000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = Vec::new(); + let mut failures = 0usize; + while let Some(joined) = tasks.join_next().await { + match joined.expect("burst acquire task must not panic") { + Ok(reservation) => successes.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user: ref denied_user } + if denied_user == user + )); + failures = failures.saturating_add(1); + } + } + } + + (successes, failures) +} + +#[tokio::test] +async fn deterministic_mixed_reservation_churn_preserves_counter_and_eventual_cleanup() { + let user = "deterministic-churn-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 12); + + let config = 
Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut seed = 0xD1F2_A4C8_991B_77E1u64; + let mut reservations: Vec> = Vec::new(); + + for step in 0..220u64 { + let op = (lcg_next(&mut seed) % 100) as u8; + let active = reservations.iter().filter(|entry| entry.is_some()).count(); + + if active == 0 || op < 55 { + let ip_octet = (lcg_next(&mut seed) % 16 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 120, ip_octet)), + 52000 + (step % 4000) as u16, + ); + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + if let Ok(reservation) = result { + reservations.push(Some(reservation)); + } else { + assert!(matches!( + result, + Err(ProxyError::ConnectionLimitExceeded { user }) if user == "deterministic-churn-user" + )); + } + } else { + let selected = reservations + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_some()) + .map(|(idx, _)| idx) + .nth((lcg_next(&mut seed) as usize) % active) + .unwrap(); + + let reservation = reservations[selected].take().unwrap(); + if op < 80 { + reservation.release().await; + } else { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("cross-thread drop must not panic"); + } + } + + let live_slots = reservations.iter().filter(|entry| entry.is_some()).count() as u64; + assert_eq!( + stats.get_user_curr_connects(user), + live_slots, + "current-connects counter must match number of live reservations" + ); + assert!( + stats.get_user_curr_connects(user) <= 12, + "current-connects must stay within configured TCP limit" + ); + assert!( + ip_tracker.get_active_ip_count(user).await <= 4, + "active unique IPs must stay within configured per-user IP limit" + ); + } + + for reservation in reservations.into_iter().flatten() { + reservation.release().await; + } + 
wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn cross_thread_drop_storm_then_parallel_reacquire_wave_has_no_leak() { + let user = "drop-storm-reacquire-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 64); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 8).await; + + let mut initial = Vec::new(); + for i in 0..32u16 { + let ip_octet = (i % 8 + 1) as u8; + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 120, ip_octet)), + 53000 + i, + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("initial reservation must succeed"); + initial.push(reservation); + } + + let mut second_half = initial.split_off(16); + + let mut releases = Vec::new(); + for reservation in initial { + releases.push(tokio::spawn(async move { + reservation.release().await; + })); + } + for release_task in releases { + release_task.await.expect("release task must not panic"); + } + + let mut drop_threads = Vec::new(); + for reservation in second_half.drain(..) 
{ + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread drop worker must not panic"); + } + + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; + + let mut reacquire_tasks = tokio::task::JoinSet::new(); + for i in 0..16u16 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + reacquire_tasks.spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 121, (i + 1) as u8)), + 54000 + i, + ); + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut acquired = Vec::new(); + while let Some(joined) = reacquire_tasks.join_next().await { + match joined.expect("reacquire task must not panic") { + Ok(reservation) => acquired.push(reservation), + Err(err) => { + assert!(matches!( + err, + ProxyError::ConnectionLimitExceeded { user } + if user == "drop-storm-reacquire-user" + )); + } + } + } + + assert!( + acquired.len() <= 8, + "parallel distinct-IP reacquire wave must not exceed per-user unique IP limit" + ); + for reservation in acquired { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_near_limit_and_burst_windows_preserve_admission_invariants() { + let user: &'static str = "scheduled-attack-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 6); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 2).await; + + let mut base = Vec::new(); + for i in 0..5u16 { + let peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 130, 1)), 56000 + i); + let reservation = 
RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("near-limit warmup reservation must succeed"); + base.push(reservation); + } + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + let (wave1_success, wave1_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 131, + 32, + ) + .await; + assert_eq!(wave1_success.len(), 1); + assert_eq!(wave1_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 6); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let released = base.pop().expect("must have releasable reservation"); + released.release().await; + for reservation in wave1_success { + reservation.release().await; + } + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 4 + && ip_tracker.get_active_ip_count(user).await == 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("window cleanup must settle to expected occupancy"); + + let (wave2_success, wave2_fail) = burst_acquire_distinct_ips( + user, + config, + stats.clone(), + ip_tracker.clone(), + 132, + 32, + ) + .await; + assert_eq!(wave2_success.len(), 1); + assert_eq!(wave2_fail, 31); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 2); + + let tail = base.split_off(2); + + let mut drop_threads = Vec::new(); + for reservation in base { + drop_threads.push(std::thread::spawn(move || { + drop(reservation); + })); + } + for drop_thread in drop_threads { + drop_thread + .join() + .expect("cross-thread scheduled cleanup must not panic"); + } + + for reservation in tail { + reservation.release().await; + } + for reservation in wave2_success { + reservation.release().await; + } + + wait_for_user_and_ip_zero(&stats, 
&ip_tracker, user).await; +} + +#[tokio::test] +async fn scheduled_mode_switch_burst_churn_preserves_limits_and_cleanup() { + let user: &'static str = "scheduled-mode-switch-user"; + let mut config = ProxyConfig::default(); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 3).await; + + let base_peer = SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::new(198, 51, 140, 1)), 57000); + let mut base = Vec::new(); + for i in 0..7u16 { + let peer = SocketAddr::new(base_peer.ip(), base_peer.port().saturating_add(i)); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .expect("base occupancy reservation must succeed"); + base.push(reservation); + } + + assert_eq!(stats.get_user_curr_connects(user), 7); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for round in 0..8u8 { + let (wave_success, wave_fail) = burst_acquire_distinct_ips( + user, + config.clone(), + stats.clone(), + ip_tracker.clone(), + 141u8.saturating_add(round), + 24, + ) + .await; + + assert!( + wave_success.len() <= 2, + "burst must not exceed available unique-IP headroom under limit=3" + ); + assert_eq!(wave_success.len() + wave_fail, 24); + assert_eq!( + stats.get_user_curr_connects(user), + 7 + wave_success.len() as u64, + "slot counter must reflect base occupancy plus successful burst leases" + ); + assert!(ip_tracker.get_active_ip_count(user).await <= 3); + + if round % 2 == 0 { + for reservation in wave_success { + reservation.release().await; + } + let rotated = base.pop().expect("base rotation reservation must exist"); + rotated.release().await; + } else { + for reservation in wave_success { + std::thread::spawn(move || { + drop(reservation); + }) + .join() + .expect("drop-heavy burst cleanup 
thread must not panic"); + } + let rotated = base.pop().expect("base rotation reservation must exist"); + std::thread::spawn(move || { + drop(rotated); + }) + .join() + .expect("drop-heavy base cleanup thread must not panic"); + } + + let replacement = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + base_peer, + ip_tracker.clone(), + ) + .await + .expect("base replacement reservation must succeed after each round"); + base.push(replacement); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if stats.get_user_curr_connects(user) == 7 + && ip_tracker.get_active_ip_count(user).await <= 1 + { + break; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("round cleanup must converge to steady base occupancy"); + } + + for reservation in base { + reservation.release().await; + } + wait_for_user_and_ip_zero(&stats, &ip_tracker, user).await; +} diff --git a/src/proxy/tests/client_timing_profile_adversarial_tests.rs b/src/proxy/tests/client_timing_profile_adversarial_tests.rs new file mode 100644 index 0000000..134990e --- /dev/null +++ b/src/proxy/tests/client_timing_profile_adversarial_tests.rs @@ -0,0 +1,367 @@ +//! Differential timing-profile adversarial tests. +//! Compare malformed in-range TLS truncation probes with plain web baselines, +//! ensuring masking behavior stays in similar latency buckets. 
+ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + +#[derive(Clone, Copy, Debug)] +enum ProbeClass { + MalformedTlsTruncation, + PlainWebBaseline, +} + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn malformed_tls_probe() -> Vec { + vec![ + 0x16, + 0x03, + 0x03, + ((MIN_TLS_CLIENT_HELLO_SIZE >> 8) & 0xff) as u8, + (MIN_TLS_CLIENT_HELLO_SIZE & 0xff) as u8, + 0x41, + ] +} + +fn plain_web_probe() -> Vec { + b"GET /timing-profile HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec() +} + +fn summarize(samples_ms: &[u128]) -> (f64, u128, u128, u128) { + let mut sorted = samples_ms.to_vec(); + sorted.sort_unstable(); + let sum: u128 = sorted.iter().copied().sum(); + let mean = sum as f64 / sorted.len() as f64; + let min = sorted[0]; + let p95_idx = ((sorted.len() as f64) * 0.95).floor() as usize; + let p95 = sorted[p95_idx.min(sorted.len() - 1)]; + let max = sorted[sorted.len() - 1]; + (mean, min, p95, max) +} + +async fn run_generic_once(class: ProbeClass) -> u128 { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = REPLY_404.to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 5]; + stream.read_exact(&mut buf).await.unwrap(); + 
stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + if matches!(class, ProbeClass::PlainWebBaseline) { + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + } + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + let probe = match class { + ProbeClass::MalformedTlsTruncation => malformed_tls_probe(), + ProbeClass::PlainWebBaseline => plain_web_probe(), + }; + + let started = Instant::now(); + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = 
tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +async fn run_client_handler_once(class: ProbeClass) -> u128 { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let backend_reply = REPLY_404.to_vec(); + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut buf = [0u8; 5]; + stream.read_exact(&mut buf).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + if matches!(class, ProbeClass::PlainWebBaseline) { + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + } + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = 
rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let probe = match class { + ProbeClass::MalformedTlsTruncation => malformed_tls_probe(), + ProbeClass::PlainWebBaseline => plain_web_probe(), + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + let started = Instant::now(); + client.write_all(&probe).await.unwrap(); + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); + + started.elapsed().as_millis() +} + +#[tokio::test] +async fn differential_timing_generic_malformed_tls_vs_plain_web_mask_profile_similar() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 20; + + let mut malformed = Vec::with_capacity(ITER); + let mut plain = Vec::with_capacity(ITER); + + for _ in 0..ITER { + malformed.push(run_generic_once(ProbeClass::MalformedTlsTruncation).await); + plain.push(run_generic_once(ProbeClass::PlainWebBaseline).await); + } + + let (m_mean, m_min, m_p95, m_max) = summarize(&malformed); + let (p_mean, p_min, p_p95, p_max) = summarize(&plain); + + println!( + "TIMING_DIFF generic class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + m_mean, + 
m_min, + m_p95, + m_max, + (m_mean as u128) / BUCKET_MS, + m_p95 / BUCKET_MS + ); + println!( + "TIMING_DIFF generic class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + p_mean, + p_min, + p_p95, + p_max, + (p_mean as u128) / BUCKET_MS, + p_p95 / BUCKET_MS + ); + + let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128); + let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128); + + assert!( + mean_bucket_delta <= 1, + "generic timing mean diverged: malformed_mean_ms={:.2}, plain_mean_ms={:.2}", + m_mean, + p_mean + ); + assert!( + p95_bucket_delta <= 2, + "generic timing p95 diverged: malformed_p95_ms={}, plain_p95_ms={}", + m_p95, + p_p95 + ); +} + +#[tokio::test] +async fn differential_timing_client_handler_malformed_tls_vs_plain_web_mask_profile_similar() { + const ITER: usize = 16; + const BUCKET_MS: u128 = 20; + + let mut malformed = Vec::with_capacity(ITER); + let mut plain = Vec::with_capacity(ITER); + + for _ in 0..ITER { + malformed.push(run_client_handler_once(ProbeClass::MalformedTlsTruncation).await); + plain.push(run_client_handler_once(ProbeClass::PlainWebBaseline).await); + } + + let (m_mean, m_min, m_p95, m_max) = summarize(&malformed); + let (p_mean, p_min, p_p95, p_max) = summarize(&plain); + + println!( + "TIMING_DIFF handler class=malformed mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + m_mean, + m_min, + m_p95, + m_max, + (m_mean as u128) / BUCKET_MS, + m_p95 / BUCKET_MS + ); + println!( + "TIMING_DIFF handler class=plain_web mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={} bucket_p95={}", + p_mean, + p_min, + p_p95, + p_max, + (p_mean as u128) / BUCKET_MS, + p_p95 / BUCKET_MS + ); + + let mean_bucket_delta = ((m_mean as i128) - (p_mean as i128)).abs() / (BUCKET_MS as i128); + let p95_bucket_delta = ((m_p95 as i128) - (p_p95 as i128)).abs() / (BUCKET_MS as i128); + + assert!( + mean_bucket_delta <= 1, 
+ "handler timing mean diverged: malformed_mean_ms={:.2}, plain_mean_ms={:.2}", + m_mean, + p_mean + ); + assert!( + p95_bucket_delta <= 2, + "handler timing p95 diverged: malformed_p95_ms={}, plain_p95_ms={}", + m_p95, + p_p95 + ); +} diff --git a/src/proxy/tests/client_tls_clienthello_size_security_tests.rs b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs new file mode 100644 index 0000000..e54791f --- /dev/null +++ b/src/proxy/tests/client_tls_clienthello_size_security_tests.rs @@ -0,0 +1,200 @@ +//! TLS ClientHello size validation tests for proxy anti-censorship security +//! Covers positive, negative, edge, adversarial, and fuzz cases. +//! Ensures proxy does not reveal itself on probe failures. + +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +fn test_probe_for_len(len: usize) -> [u8; 5] { + [ + 0x16, + 0x03, + 0x03, + ((len >> 8) & 0xff) as u8, + (len & 0xff) as u8, + ] +} + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +async fn run_probe_and_assert_masking(len: usize, expect_bad_increment: bool) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = test_probe_for_len(len); + let backend_reply = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 
5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe, "mask backend must receive original probe bytes"); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply, "invalid TLS path must be masked as a real site"); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + let expected_bad = if expect_bad_increment { bad_before + 1 } else { bad_before }; + assert_eq!( + stats.get_connects_bad(), + expected_bad, + "unexpected connects_bad classification for tls_len={len}" + 
); +} + +#[tokio::test] +async fn tls_client_hello_lower_bound_minus_one_is_masked_and_counted_bad() { + run_probe_and_assert_masking(MIN_TLS_CLIENT_HELLO_SIZE - 1, true).await; +} + +#[tokio::test] +async fn tls_client_hello_upper_bound_plus_one_is_masked_and_counted_bad() { + run_probe_and_assert_masking(MAX_TLS_PLAINTEXT_SIZE + 1, true).await; +} + +#[tokio::test] +async fn tls_client_hello_header_zero_len_is_masked_and_counted_bad() { + run_probe_and_assert_masking(0, true).await; +} + +#[test] +fn tls_client_hello_len_bounds_unit_adversarial_sweep() { + let cases = [ + (0usize, false), + (1usize, false), + (99usize, false), + (100usize, true), + (101usize, true), + (511usize, true), + (512usize, true), + (MAX_TLS_PLAINTEXT_SIZE - 1, true), + (MAX_TLS_PLAINTEXT_SIZE, true), + (MAX_TLS_PLAINTEXT_SIZE + 1, false), + (u16::MAX as usize, false), + (usize::MAX, false), + ]; + + for (len, expected) in cases { + assert_eq!( + tls_clienthello_len_in_bounds(len), + expected, + "unexpected bounds result for tls_len={len}" + ); + } +} + +#[test] +fn tls_client_hello_len_bounds_light_fuzz_deterministic_lcg() { + let mut x: u32 = 0xA5A5_5A5A; + for _ in 0..2_048 { + x = x.wrapping_mul(1_664_525).wrapping_add(1_013_904_223); + let base = (x as usize) & 0x3fff; + let len = match x & 0x7 { + 0 => MIN_TLS_CLIENT_HELLO_SIZE - 1, + 1 => MIN_TLS_CLIENT_HELLO_SIZE, + 2 => MIN_TLS_CLIENT_HELLO_SIZE + 1, + 3 => MAX_TLS_PLAINTEXT_SIZE - 1, + 4 => MAX_TLS_PLAINTEXT_SIZE, + 5 => MAX_TLS_PLAINTEXT_SIZE + 1, + _ => base, + }; + let expect_bad = !(MIN_TLS_CLIENT_HELLO_SIZE..=MAX_TLS_PLAINTEXT_SIZE).contains(&len); + assert_eq!( + tls_clienthello_len_in_bounds(len), + !expect_bad, + "deterministic fuzz mismatch for tls_len={len}" + ); + } +} + +#[test] +fn tls_client_hello_len_bounds_stress_many_evaluations() { + for _ in 0..100_000 { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + 
assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); + } +} + +#[tokio::test] +async fn tls_client_hello_masking_integration_repeated_small_probes() { + for _ in 0..25 { + run_probe_and_assert_masking(MIN_TLS_CLIENT_HELLO_SIZE - 1, true).await; + } +} diff --git a/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs new file mode 100644 index 0000000..6ac02dd --- /dev/null +++ b/src/proxy/tests/client_tls_clienthello_truncation_adversarial_tests.rs @@ -0,0 +1,561 @@ +//! Black-hat adversarial tests for truncated in-range TLS ClientHello probes. +//! These tests encode a strict anti-probing expectation: malformed TLS traffic +//! should still be masked as a legitimate website response. + +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::time::sleep; + +fn in_range_probe_header() -> [u8; 5] { + [ + 0x16, + 0x03, + 0x03, + ((MIN_TLS_CLIENT_HELLO_SIZE >> 8) & 0xff) as u8, + (MIN_TLS_CLIENT_HELLO_SIZE & 0xff) as u8, + ] +} + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn truncated_in_range_record(actual_body_len: usize) -> Vec { + let mut out = in_range_probe_header().to_vec(); + out.extend(std::iter::repeat_n(0x41, actual_body_len)); + out +} + +async fn write_fragmented(writer: &mut W, bytes: &[u8], chunks: &[usize], delay_ms: u64) { + let mut offset = 0usize; + 
for &chunk in chunks { + if offset >= bytes.len() { + break; + } + let end = (offset + chunk).min(bytes.len()); + writer.write_all(&bytes[offset..end]).await.unwrap(); + offset = end; + if delay_ms > 0 { + sleep(Duration::from_millis(delay_ms)).await; + } + } + if offset < bytes.len() { + writer.write_all(&bytes[offset..]).await.unwrap(); + } +} + +async fn run_blackhat_generic_fragmented_probe_should_mask( + payload: Vec, + chunks: &[usize], + delay_ms: u64, + backend_reply: Vec, +) { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + let probe_header = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe_header); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.202:55002".parse().unwrap(); + + let handler = 
tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + write_fragmented(&mut client_side, &payload, chunks, delay_ms).await; + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); +} + +async fn run_blackhat_client_handler_fragmented_probe_should_mask( + payload: Vec, + chunks: &[usize], + delay_ms: u64, + backend_reply: Vec, +) { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe_header = in_range_probe_header(); + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe_header); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); 
+ let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + write_fragmented(&mut client, &payload, chunks, delay_ms).await; + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_truncated_in_range_clienthello_generic_stream_should_mask() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 404 Not 
Found\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.201:55001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + // Security expectation: even malformed in-range TLS should be masked. + // This invariant must hold to avoid probe-distinguishable EOF/timeout behavior. 
+ let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_truncated_in_range_clienthello_client_handler_should_mask() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let mask_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe = in_range_probe_header(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = 
Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + client.shutdown().await.unwrap(); + + // Security expectation: malformed in-range TLS should still be masked. 
+ let mut observed = vec![0u8; backend_reply.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(2), mask_accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_1_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[6], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_8_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(8), + &[13], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_truncated_min_body_99_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), + &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_fragmented_header_then_close_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(0), + &[1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_fragmented_header_plus_partial_body_should_mask() { + run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(5), + &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_generic_slowloris_fragmented_min_probe_should_mask_but_times_out() { + 
run_blackhat_generic_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[1, 1, 1, 1, 1, 1], + 250, + b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_1_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[6], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_8_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(8), + &[13], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_truncated_min_body_99_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(MIN_TLS_CLIENT_HELLO_SIZE - 1), + &[5, MIN_TLS_CLIENT_HELLO_SIZE - 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_fragmented_header_then_close_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(0), + &[1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_fragmented_header_plus_partial_body_should_mask() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(5), + &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + 0, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(), + ) + .await; +} + +#[tokio::test] +async fn blackhat_client_handler_slowloris_fragmented_min_probe_should_mask_but_times_out() { + run_blackhat_client_handler_fragmented_probe_should_mask( + truncated_in_range_record(1), + &[1, 1, 1, 1, 1, 1], + 250, + b"HTTP/1.1 403 Forbidden\r\nContent-Length: 
0\r\n\r\n".to_vec(), + ) + .await; +} diff --git a/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs new file mode 100644 index 0000000..920c013 --- /dev/null +++ b/src/proxy/tests/client_tls_mtproto_fallback_security_tests.rs @@ -0,0 +1,2861 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{ + HANDSHAKE_LEN, + MAX_TLS_CIPHERTEXT_SIZE, + TLS_RECORD_ALERT, + TLS_RECORD_APPLICATION, + TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_HANDSHAKE, + TLS_VERSION, +}; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: 
Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(record_type); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_invalid_mtproto_with_coalesced_tail(tail: &[u8]) -> Vec { + let mut payload = vec![0u8; HANDSHAKE_LEN]; + payload.extend_from_slice(tail); + wrap_tls_application_data(&payload) +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: 
tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x81u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x42); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = b"masked-trailing-record".to_vec(); + let trailing_record = wrap_tls_application_data(&trailing_payload); + let backend_response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + let expected_trailing_record = trailing_record.clone(); + let expected_response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + + stream.write_all(&expected_response).await.unwrap(); + }); + + let harness = build_harness("81818181818181818181818181818181", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.181:56001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + 
assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x82u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 1, 600, 0x43); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(b"x"); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("82828282828282828282828282828282", backend_addr.port()); + let bad_before = harness.stats.get_connects_bad(); + + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.182:56002".parse().unwrap(); + let stats_for_assert = harness.stats.clone(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut 
tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS"); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x83u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 2, 600, 0x44); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(&[]); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("83838383838383838383838383838383", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.183:56003".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + 
harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x84u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 3, 600, 0x45); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![0xAB; MAX_TLS_CIPHERTEXT_SIZE]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("84848484848484848484848484848484", backend_addr.port()); + let (server_side, mut client_side) = duplex(2 * 1024 * 1024); + let peer: SocketAddr = "198.51.100.184:56004".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + 
harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { + let lengths = [0usize, 1, 2, 3, 7, 15, 31, 63, 127, 255, 1024, 4096]; + + for (idx, payload_len) in lengths.iter().copied().enumerate() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x85u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + + let mut payload = vec![0u8; payload_len]; + for (i, b) in payload.iter_mut().enumerate() { + *b = ((idx as u8).wrapping_mul(29)).wrapping_add(i as u8); + } + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("85858585858585858585858585858585", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = format!("198.51.100.185:{}", 56010 + idx as u16) + .parse() + 
.unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { + let sessions = 24usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_records = std::collections::HashSet::new(); + let secret = [0x86u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let payload = vec![idx as u8; 64 + idx]; + let trailing = wrap_tls_application_data(&payload); + expected_records.insert(trailing); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_records; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + 
remaining.remove(&record), + "unexpected trailing TLS record in concurrent isolation test" + ); + } + + assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once"); + }); + + let mut client_tasks = Vec::with_capacity(sessions); + + for idx in 0..sessions { + let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); + let secret = [0x86u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![idx as u8; 64 + idx]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + + let peer: SocketAddr = format!("198.51.100.186:{}", 57000 + idx as u16) + .parse() + .unwrap(); + + client_tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in client_tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x87u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 9, 600, 0x57); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let payload = b"fragmented-writes-to-test-stream-boundary-robustness".to_vec(); + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("87878787878787878787878787878787", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.187:56087".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + for chunk in trailing_record.chunks(3) { + client_side.write_all(chunk).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn 
tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 10, 600, 0x58); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"bytewise-header"); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("88888888888888888888888888888888", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.188:56088".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + for b in trailing_record.iter().copied() { + client_side.write_all(&[b]).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x89u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x59); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let mut payload = vec![0u8; 2048]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(3); + } + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("89898989898989898989898989898989", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.189:56089".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; + let mut pos = 0usize; + let mut idx = 0usize; + while pos < trailing_record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(trailing_record.len()); + client_side.write_all(&trailing_record[pos..end]).await.unwrap(); + pos = end; + idx += 1; + } + + tokio::time::timeout(Duration::from_secs(3), 
accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Au8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x5A); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let r1 = wrap_tls_application_data(b"alpha"); + let r2 = wrap_tls_application_data(b"beta-beta"); + let r3 = wrap_tls_application_data(b"gamma-gamma-gamma"); + let expected = [r1.clone(), r2.clone(), r3.clone()].concat(); + let expected_concat = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_concat.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_concat); + }); + + let harness = build_harness("8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.190:56090".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&r1).await.unwrap(); + client_side.write_all(&r2).await.unwrap(); + 
client_side.write_all(&r3).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Bu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x5B); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"half-close-probe"); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + let mut tail = [0u8; 1]; + let n = stream.read(&mut tail).await.unwrap(); + assert_eq!(n, 0, "backend must observe EOF after client write half-close"); + }); + + let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.191:56091".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + 
client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Cu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x5C); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-half-close"); + let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let expected_trailing = trailing_record.clone(); + let response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let harness = build_harness("8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.192:56092".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 
5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Du8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x5D); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-reset"); + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let harness = build_harness("8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.193:56093".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + let write_res = client_side.write_all(&trailing_record).await; + assert!( + 
write_res.is_ok() || write_res.is_err(), + "write completion is environment dependent under backend reset" + ); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Eu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 16, 600, 0x5E); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let payload = vec![0xEC; 8192]; + let trailing_record = wrap_tls_application_data(&payload); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + let mut offset = 0usize; + while offset < got_trailing.len() { + let step = (offset % 97).max(1).min(got_trailing.len() - offset); + stream + .read_exact(&mut got_trailing[offset..offset + step]) + .await + .unwrap(); + offset += step; + tokio::time::sleep(Duration::from_millis(1)).await; + } + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.194:56094".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut 
tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhello() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Fu8; 16]; + let replayed_hello = make_valid_tls_client_hello(&secret, 17, 600, 0x5F); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"first-session"); + + let expected_second = replayed_hello.clone(); + let expected_trailing = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got1_tail = vec![0u8; expected_trailing.len()]; + s1.read_exact(&mut got1_tail).await.unwrap(); + assert_eq!(got1_tail, expected_trailing); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got2 = vec![0u8; expected_second.len()]; + s2.read_exact(&mut got2).await.unwrap(); + assert_eq!(got2, expected_second); + }); + + let harness = build_harness("8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f", backend_addr.port()); + let stats_for_assert = harness.stats.clone(); + let bad_before = stats_for_assert.get_connects_bad(); + + let run_session = |hello: Vec<u8>, send_mtproto: bool| { + let (server_side, mut client_side) = duplex(131072); + let config = harness.config.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let replay =
harness.replay_checker.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let trailing_record = trailing_record.clone(); + async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.195:56095".parse().unwrap(), + config, + stats, + upstream, + replay, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!( + no_server_hello.is_err() || no_server_hello.unwrap().is_err(), + "replayed TLS hello must not receive authenticated TLS ServerHello" + ); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } + }; + + run_session(replayed_hello.clone(), true).await; + run_session(replayed_hello.clone(), false).await; + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert!( + bad_after >= bad_before + 2, + "both invalid-mtproto and replayed-tls paths must increment bad connection accounting" + ); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x90u8; 16]; + let client_hello = 
make_valid_tls_client_hello(&secret, 18, 600, 0x60); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let a = wrap_tls_application_data(&vec![0xA1; 2048]); + let b = wrap_tls_application_data(&vec![0xB2; 3072]); + let c = wrap_tls_application_data(&vec![0xC3; 1536]); + let expected = [a.clone(), b.clone(), c.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_payload.len()]; + let mut pos = 0usize; + while pos < got.len() { + let step = (pos % 257).max(1).min(got.len() - pos); + stream.read_exact(&mut got[pos..pos + step]).await.unwrap(); + pos += step; + tokio::time::sleep(Duration::from_millis(1)).await; + } + assert_eq!(got, expected_payload); + }); + + let harness = build_harness("90909090909090909090909090909090", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.196:56096".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; + for record in [&a, &b, &c] { + let mut pos = 0usize; + let mut idx = 0usize; + while pos < record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(record.len()); + client_side.write_all(&record[pos..end]).await.unwrap(); + pos = end; + idx += 1; + } + } + + 
tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x91u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 19, 600, 0x61); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let ccs = wrap_tls_record(0x14, &[0x01]); + let app = wrap_tls_application_data(b"opaque"); + let alert = wrap_tls_record(0x15, &[0x01, 0x00]); + let expected = [ccs.clone(), app.clone(), alert.clone()].concat(); + let expected_records = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_records.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_records); + }); + + let harness = build_harness("91919191919191919191919191919191", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.197:56097".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&ccs).await.unwrap(); + 
client_side.write_all(&app).await.unwrap(); + client_side.write_all(&alert).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() { + let sessions = 40usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_records = std::collections::HashSet::new(); + let secret = [0x92u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + expected_records.insert(record); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_records; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "unexpected trailing TLS record in short-session chaos test" + ); + } + + assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once"); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("92929292929292929292929292929292", backend_addr.port()); + let secret = [0x92u8; 16]; + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let invalid_mtproto_record = 
wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + + let peer: SocketAddr = format!("198.51.100.198:{}", 58000 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + for chunk in record.chunks((idx % 9) + 1) { + client_side.write_all(chunk).await.unwrap(); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_small_is_forwarded_as_tls_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 300, 600, 0x31); + let coalesced_tail = b"coalesced-tail-small".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; 
expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.210:56110".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_large_is_forwarded_as_tls_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 301, 600, 0x32); + let coalesced_tail = vec![0xAB; 4096]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = 
build_harness("a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2a2", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.211:56111".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_keeps_order_before_following_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 302, 600, 0x33); + let coalesced_tail = b"coalesced-first".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let following_record = wrap_tls_application_data(b"following-record"); + let expected_concat = [expected_tail_record.clone(), following_record.clone()].concat(); + let expected_records = expected_concat.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_records = vec![0u8; expected_records.len()]; + stream.read_exact(&mut got_records).await.unwrap(); + assert_eq!(got_records, expected_records); + 
}); + + let harness = build_harness("a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3a3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.212:56112".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&following_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_fragmented_client_write_is_forwarded() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA4u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 303, 600, 0x34); + let coalesced_tail = vec![0xCD; 1536]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4", backend_addr.port()); + let 
(server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.213:56113".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let steps = [7usize, 3, 13, 5, 11, 2, 17, 19]; + let mut offset = 0usize; + let mut i = 0usize; + while offset < coalesced_record.len() { + let step = steps[i % steps.len()]; + let end = (offset + step).min(coalesced_record.len()); + client_side + .write_all(&coalesced_record[offset..end]) + .await + .unwrap(); + offset = end; + i += 1; + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_coalesced_tail_max_payload_is_forwarded() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xA5u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 304, 600, 0x35); + let coalesced_tail = vec![0xEF; MAX_TLS_CIPHERTEXT_SIZE - HANDSHAKE_LEN]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&coalesced_tail); + let expected_tail_record = wrap_tls_application_data(&coalesced_tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut 
got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.214:56114".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_identical_following_record_must_not_duplicate_or_reorder() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 400, 600, 0x21); + let tail = b"same-payload-record".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let expected = [tail_record.clone(), tail_record.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_payload.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_payload); + + let mut tail = [0u8; 1]; + let n = stream.read(&mut 
tail).await.unwrap(); + assert_eq!(n, 0, "fallback stream must not emit extra bytes"); + }); + + let harness = build_harness("b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1b1", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.220:56120".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&tail_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_tls_header_looking_bytes_must_stay_payload() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 401, 600, 0x22); + let mut tail = vec![0x16, 0x03, 0x03, 0x00, 0x10]; + tail.extend_from_slice(b"not-a-real-record-boundary"); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail_record = wrap_tls_application_data(&tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + 
assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2b2", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.221:56121".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_client_half_close_must_not_truncate_prepended_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xB3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 402, 600, 0x23); + let tail = vec![0xAA; 3072]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail_record = wrap_tls_application_data(&tail); + let expected_tail = expected_tail_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!(n, 0, "backend must observe EOF after client half-close"); + }); + + 
let harness = build_harness("b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.222:56122".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_multi_session_no_cross_bleed_under_churn() { + let sessions = 16usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + let secret = [0xB4u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, 450 + idx as u32, 600, 0x40 + idx as u8); + let tail = vec![idx as u8; 17 + idx]; + expected.insert(wrap_tls_application_data(&tail)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + 
stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "unexpected record or duplicated session routing" + ); + } + assert!(remaining.is_empty(), "all sessions must map one-to-one"); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4b4", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 450 + idx as u32, 600, 0x40 + idx as u8); + let tail = vec![idx as u8; 17 + idx]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let peer: SocketAddr = format!("198.51.100.223:{}", 56200 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + for chunk in coalesced_record.chunks((idx % 7) + 1) { + client_side.write_all(chunk).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_single_byte_tail_is_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC1u8; 16]; + let client_hello = 
make_valid_tls_client_hello(&secret, 500, 600, 0x11); + let tail = vec![0x7F]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.230:56130".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_exact_tls_header_size_payload_is_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC2u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 501, 600, 0x12); + let tail = vec![0xAA, 0xBB, 0xCC, 0xDD, 0xEE]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async 
move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.231:56131".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_all_zero_payload_is_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC3u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 502, 600, 0x13); + let tail = vec![0u8; 2048]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = 
build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.232:56132".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_following_control_records_are_not_mutated() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC4u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 503, 600, 0x14); + let tail = b"tail-before-controls".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let ccs = wrap_tls_record(0x14, &[0x01]); + let alert = wrap_tls_record(0x15, &[0x01, 0x00]); + let app = wrap_tls_application_data(b"control-final-app"); + let expected = [tail_record, ccs.clone(), alert.clone(), app.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_payload = vec![0u8; expected_payload.len()]; + stream.read_exact(&mut got_payload).await.unwrap(); + assert_eq!(got_payload, 
expected_payload); + }); + + let harness = build_harness("c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.233:56133".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.write_all(&ccs).await.unwrap(); + client_side.write_all(&alert).await.unwrap(); + client_side.write_all(&app).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_then_following_records_fragmented_chaos_stays_ordered() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC5u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 504, 600, 0x15); + let tail = vec![0xAC; 900]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let tail_record = wrap_tls_application_data(&tail); + let r1 = wrap_tls_application_data(b"r1"); + let r2 = wrap_tls_application_data(&vec![0xDD; 257]); + let expected = [tail_record, r1.clone(), r2.clone()].concat(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_payload = 
vec![0u8; expected_payload.len()]; + stream.read_exact(&mut got_payload).await.unwrap(); + assert_eq!(got_payload, expected_payload); + }); + + let harness = build_harness("c5c5c5c5c5c5c5c5c5c5c5c5c5c5c5c5", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.234:56134".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let pattern = [3usize, 1, 5, 2, 7, 11, 13, 17, 19]; + let mut idx = 0usize; + for data in [&coalesced_record, &r1, &r2] { + let mut pos = 0usize; + while pos < data.len() { + let step = pattern[idx % pattern.len()]; + idx += 1; + let end = (pos + step).min(data.len()); + client_side.write_all(&data[pos..end]).await.unwrap(); + pos = end; + } + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_backend_response_integrity_after_fallback() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC6u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 505, 600, 0x16); + let tail = b"coalesced-request-body".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 
0\r\n\r\n".to_vec(); + let expected_resp = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + stream.write_all(&expected_resp).await.unwrap(); + }); + + let harness = build_harness("c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.235:56135".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + let mut observed = Vec::new(); + let mut buf = [0u8; 512]; + let mut found = false; + for _ in 0..32 { + let n = tokio::time::timeout(Duration::from_millis(200), client_side.read(&mut buf)) + .await + .unwrap() + .unwrap(); + if n == 0 { + break; + } + observed.extend_from_slice(&buf[..n]); + if observed + .windows(backend_response.len()) + .any(|w| w == backend_response.as_slice()) + { + found = true; + break; + } + } + assert!( + found, + "backend plaintext response must be observable on client stream after fallback" + ); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_connects_bad_increments_exactly_once() { + 
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC7u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 506, 600, 0x17); + let tail = b"count-bad-once".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + + let harness = build_harness("c7c7c7c7c7c7c7c7c7c7c7c7c7c7c7c7", backend_addr.port()); + let stats = harness.stats.clone(); + let bad_before = stats.get_connects_bad(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.236:56136".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + let bad_after = stats.get_connects_bad(); + assert_eq!( + bad_after, + bad_before + 1, + "invalid MTProto after valid TLS must increment connects_bad exactly once" + ); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_parallel_32_sessions_no_cross_bleed() { + let 
sessions = 32usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + let secret = [0xC8u8; 16]; + for idx in 0..sessions { + let _hello = make_valid_tls_client_hello(&secret, 550 + idx as u32, 600, 0x20 + idx as u8); + let tail = vec![idx as u8; 48 + (idx % 11)]; + expected.insert(wrap_tls_application_data(&tail)); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(header[0], TLS_RECORD_APPLICATION); + + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut record = vec![0u8; 5 + len]; + record[..5].copy_from_slice(&header); + stream.read_exact(&mut record[5..]).await.unwrap(); + + assert!( + remaining.remove(&record), + "session mixup detected in parallel-32 blackhat test" + ); + } + assert!(remaining.is_empty(), "all expected sessions must be consumed"); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8", backend_addr.port()); + let hello = make_valid_tls_client_hello(&secret, 550 + idx as u32, 600, 0x20 + idx as u8); + let tail = vec![idx as u8; 48 + (idx % 11)]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let peer: SocketAddr = format!("198.51.100.237:{}", 56300 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + 
)); + + client_side.write_all(&hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let chunk = (idx % 13) + 1; + for part in coalesced_record.chunks(chunk) { + client_side.write_all(part).await.unwrap(); + } + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_repeated_tls_like_prefixes_are_preserved() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xC9u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 507, 600, 0x18); + let mut tail = Vec::new(); + for _ in 0..64 { + tail.extend_from_slice(&[0x16, 0x03, 0x03, 0x00, 0x20]); + } + tail.extend_from_slice(b"suffix-data"); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("c9c9c9c9c9c9c9c9c9c9c9c9c9c9c9c9", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.238:56138".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); 
+ + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_drop_after_write_still_delivers_prepended_record() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xCAu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 508, 600, 0x19); + let tail = vec![0xBE; 1024]; + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + }); + + let harness = build_harness("cacacacacacacacacacacacacacacaca", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.239:56139".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + 
client_side.write_all(&coalesced_record).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_zero_following_record_after_coalesced_is_not_invented() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0xCBu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 509, 600, 0x1A); + let tail = b"terminal-tail".to_vec(); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_tail = vec![0u8; expected_tail.len()]; + stream.read_exact(&mut got_tail).await.unwrap(); + assert_eq!(got_tail, expected_tail); + + let mut one = [0u8; 1]; + let n = stream.read(&mut one).await.unwrap(); + assert_eq!(n, 0, "no synthetic extra record must appear"); + }); + + let harness = build_harness("cbcbcbcbcbcbcbcbcbcbcbcbcbcbcbcb", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.240:56140".parse().unwrap(), + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + client_side.write_all(&coalesced_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + 
tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn blackhat_coalesced_tail_light_fuzz_mixed_followup_records_stay_byte_exact() { + let mut seed = 0xA11C_E2E5_F00D_BAADu64; + + for case in 0..24u32 { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let tail_len = (seed as usize % 1536) + 1; + let mut tail = vec![0u8; tail_len]; + for (i, b) in tail.iter_mut().enumerate() { + *b = (seed as u8).wrapping_add(i as u8).wrapping_mul(13); + } + + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let follow_type = match seed & 0x3 { + 0 => TLS_RECORD_APPLICATION, + 1 => TLS_RECORD_ALERT, + 2 => TLS_RECORD_CHANGE_CIPHER, + _ => TLS_RECORD_HANDSHAKE, + }; + let follow_len = (seed as usize % 96) + (case as usize % 3); + let mut follow_payload = vec![0u8; follow_len]; + for (i, b) in follow_payload.iter_mut().enumerate() { + *b = (case as u8).wrapping_mul(29).wrapping_add(i as u8); + } + + let secret = [0xD1u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 600 + case, 600, 0x33); + let coalesced_record = wrap_invalid_mtproto_with_coalesced_tail(&tail); + let expected_tail = wrap_tls_application_data(&tail); + let follow_record = wrap_tls_record(follow_type, &follow_payload); + let expected_wire = [expected_tail.clone(), follow_record.clone()].concat(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got = vec![0u8; expected_wire.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_wire); + }); + + let harness = build_harness("d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: 
SocketAddr = format!("198.51.100.250:{}", 57000 + case as u16) + .parse() + .unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, head).await; + + let mut local_seed = seed ^ 0x55AA_55AA_1234_5678; + for data in [&coalesced_record, &follow_record] { + let mut pos = 0usize; + while pos < data.len() { + local_seed ^= local_seed << 7; + local_seed ^= local_seed >> 9; + local_seed ^= local_seed << 8; + let step = ((local_seed as usize % 17) + 1).min(data.len() - pos); + let end = pos + step; + client_side.write_all(&data[pos..end]).await.unwrap(); + pos = end; + } + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } +} diff --git a/src/proxy/tests/direct_relay_business_logic_tests.rs b/src/proxy/tests/direct_relay_business_logic_tests.rs new file mode 100644 index 0000000..166518e --- /dev/null +++ b/src/proxy/tests/direct_relay_business_logic_tests.rs @@ -0,0 +1,51 @@ +use super::*; +use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4, TG_DATACENTERS_V6}; +use std::net::SocketAddr; + +#[test] +fn business_scope_hint_accepts_exact_boundary_length() { + let value = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN)); + assert_eq!(validated_scope_hint(&value), Some("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); +} + +#[test] +fn 
business_scope_hint_rejects_missing_prefix_even_when_charset_is_valid() { + assert_eq!(validated_scope_hint("alpha-01"), None); +} + +#[test] +fn business_known_dc_uses_ipv4_table_by_default() { + let cfg = ProxyConfig::default(); + let resolved = get_dc_addr_static(2, &cfg).expect("known dc must resolve"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_negative_dc_maps_by_absolute_value() { + let cfg = ProxyConfig::default(); + let resolved = get_dc_addr_static(-3, &cfg).expect("negative dc index must map by absolute value"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[2], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_known_dc_uses_ipv6_table_when_preferred_and_enabled() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + + let resolved = get_dc_addr_static(1, &cfg).expect("known dc must resolve on ipv6 path"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} + +#[test] +fn business_unknown_dc_uses_configured_default_dc_when_in_range() { + let mut cfg = ProxyConfig::default(); + cfg.default_dc = Some(4); + + let resolved = get_dc_addr_static(29_999, &cfg).expect("unknown dc must resolve to configured default"); + let expected = SocketAddr::new(TG_DATACENTERS_V4[3], TG_DATACENTER_PORT); + assert_eq!(resolved, expected); +} diff --git a/src/proxy/tests/direct_relay_common_mistakes_tests.rs b/src/proxy/tests/direct_relay_common_mistakes_tests.rs new file mode 100644 index 0000000..ef40f37 --- /dev/null +++ b/src/proxy/tests/direct_relay_common_mistakes_tests.rs @@ -0,0 +1,98 @@ +use super::*; +use crate::protocol::constants::{TG_DATACENTER_PORT, TG_DATACENTERS_V4}; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::sync::Mutex; + +#[test] +fn common_invalid_override_entries_fallback_to_static_table() { + 
    let mut cfg = ProxyConfig::default();
+    cfg.dc_overrides.insert(
+        "2".to_string(),
+        vec!["bad-address".to_string(), "still-bad".to_string()],
+    );
+
+    let resolved = get_dc_addr_static(2, &cfg).expect("fallback to static table must still resolve");
+    let expected = SocketAddr::new(TG_DATACENTERS_V4[1], TG_DATACENTER_PORT);
+    assert_eq!(resolved, expected);
+}
+
+#[test]
+fn common_prefer_v6_with_only_ipv4_override_uses_override_instead_of_ignoring_it() {
+    let mut cfg = ProxyConfig::default();
+    cfg.network.prefer = 6;
+    cfg.network.ipv6 = Some(true);
+    cfg.dc_overrides
+        .insert("3".to_string(), vec!["203.0.113.203:443".to_string()]);
+
+    let resolved = get_dc_addr_static(3, &cfg).expect("ipv4 override must be used if no ipv6 override exists");
+    assert_eq!(resolved, "203.0.113.203:443".parse::<SocketAddr>().unwrap());
+}
+
+#[test]
+fn common_scope_hint_rejects_unicode_lookalike_characters() {
+    assert_eq!(validated_scope_hint("scope_аlpha"), None);
+    assert_eq!(validated_scope_hint("scope_Αlpha"), None);
+}
+
+#[cfg(unix)]
+#[test]
+fn common_anchored_open_rejects_nul_filename() {
+    use std::os::unix::ffi::OsStringExt;
+
+    let parent = std::env::current_dir()
+        .expect("cwd must be available")
+        .join("target")
+        .join(format!("telemt-direct-relay-nul-{}", std::process::id()));
+    std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
+
+    let path = SanitizedUnknownDcLogPath {
+        resolved_path: parent.join("placeholder.log"),
+        allowed_parent: parent,
+        file_name: std::ffi::OsString::from_vec(vec![b'a', 0, b'b']),
+    };
+
+    let err = open_unknown_dc_log_append_anchored(&path)
+        .expect_err("anchored open must fail on NUL in filename");
+    assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
+}
+
+#[cfg(unix)]
+#[test]
+fn common_anchored_open_creates_owner_only_file_permissions() {
+    use std::os::unix::fs::PermissionsExt;
+
+    let parent = std::env::current_dir()
+        .expect("cwd must be available")
+        .join("target")
        .join(format!("telemt-direct-relay-perm-{}", std::process::id()));
+    std::fs::create_dir_all(&parent).expect("parent directory must be creatable");
+
+    let sanitized = SanitizedUnknownDcLogPath {
+        resolved_path: parent.join("unknown-dc.log"),
+        allowed_parent: parent.clone(),
+        file_name: std::ffi::OsString::from("unknown-dc.log"),
+    };
+
+    let mut file = open_unknown_dc_log_append_anchored(&sanitized)
+        .expect("anchored open must create regular file");
+    use std::io::Write;
+    writeln!(file, "dc_idx=1").expect("write must succeed");
+
+    let mode = std::fs::metadata(parent.join("unknown-dc.log"))
+        .expect("metadata must be readable")
+        .permissions()
+        .mode()
+        & 0o777;
+    assert_eq!(mode, 0o600);
+}
+
+#[test]
+fn common_duplicate_dc_attempts_do_not_consume_unique_slots() {
+    let set = Mutex::new(HashSet::new());
+
+    assert!(should_log_unknown_dc_with_set(&set, 100));
+    assert!(!should_log_unknown_dc_with_set(&set, 100));
+    assert!(should_log_unknown_dc_with_set(&set, 101));
+    assert_eq!(set.lock().expect("set lock must be available").len(), 2);
+}
diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs
new file mode 100644
index 0000000..7c3a51e
--- /dev/null
+++ b/src/proxy/tests/direct_relay_security_tests.rs
@@ -0,0 +1,1567 @@
+use super::*;
+use crate::config::{UpstreamConfig, UpstreamType};
+use crate::crypto::{AesCtr, SecureRandom};
+use crate::protocol::constants::ProtoTag;
+use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
+use crate::stats::Stats;
+use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
+use crate::transport::UpstreamManager;
+use std::fs;
+use std::io::Write;
+use std::path::Path;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::time::Duration;
+use tokio::io::AsyncReadExt;
+use tokio::io::duplex;
+use tokio::net::TcpListener;
+use tokio::time::{timeout, Duration as TokioDuration};
+
+fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
+where
+    R: tokio::io::AsyncRead + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoReader::new(reader, AesCtr::new(&key, iv))
+}
+
+fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
+where
+    W: tokio::io::AsyncWrite + Unpin,
+{
+    let key = [0u8; 32];
+    let iv = 0u128;
+    CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
+}
+
+fn nonempty_line_count(text: &str) -> usize {
+    text.lines().filter(|line| !line.trim().is_empty()).count()
+}
+
+#[test]
+fn unknown_dc_log_is_deduplicated_per_dc_idx() {
+    let _guard = unknown_dc_test_lock()
+        .lock()
+        .expect("unknown dc test lock must be available");
+    clear_unknown_dc_log_cache_for_testing();
+
+    assert!(should_log_unknown_dc(777));
+    assert!(
+        !should_log_unknown_dc(777),
+        "same unknown dc_idx must not be logged repeatedly"
+    );
+    assert!(
+        should_log_unknown_dc(778),
+        "different unknown dc_idx must still be loggable"
+    );
+}
+
+#[test]
+fn unknown_dc_log_respects_distinct_limit() {
+    let _guard = unknown_dc_test_lock()
+        .lock()
+        .expect("unknown dc test lock must be available");
+    clear_unknown_dc_log_cache_for_testing();
+
+    for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT {
+        assert!(
+            should_log_unknown_dc(dc as i16),
+            "expected first-time unknown dc_idx to be loggable"
+        );
+    }
+
+    assert!(
+        !should_log_unknown_dc(i16::MAX),
+        "distinct unknown dc_idx entries above limit must not be logged"
+    );
+}
+
+#[test]
+fn unknown_dc_log_fails_closed_when_dedup_lock_is_poisoned() {
+    let poisoned = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<i16>::new()));
+    let poisoned_for_thread = poisoned.clone();
+
+    let _ = std::thread::spawn(move || {
+        let _guard = poisoned_for_thread
+            .lock()
+            .expect("poison setup lock must be available");
+        panic!("intentional poison for fail-closed regression");
+    })
+    .join();
+
+    assert!(
+        !should_log_unknown_dc_with_set(poisoned.as_ref(), 4242),
+        "poisoned unknown-DC dedup lock must fail closed"
+    );
+}
+
+#[test]
+fn
unsafe_unknown_dc_log_path_does_not_consume_dedup_slot() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_123; + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some("../telemt-unknown-dc-unsafe.log".to_string()); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + assert!( + should_log_unknown_dc(dc_idx), + "rejected unsafe log path must not consume unknown-dc dedup entry" + ); +} + +#[test] +fn stress_unknown_dc_log_concurrent_unique_churn_respects_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let accepted = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + // Adversarial model: many concurrent peers rotate dc_idx values rapidly. + for worker in 0..16usize { + let accepted = Arc::clone(&accepted); + workers.push(std::thread::spawn(move || { + let base = (worker * 2048) as i32; + for offset in 0..512i32 { + let raw = base + offset; + let dc = (raw % i16::MAX as i32) as i16; + if should_log_unknown_dc(dc) { + accepted.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must not panic"); + } + + assert_eq!( + accepted.load(Ordering::Relaxed), + UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "concurrent unique churn must never admit more than the configured distinct cap" + ); +} + +#[test] +fn light_fuzz_unknown_dc_log_mixed_duplicates_never_exceeds_cap() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + // Deterministic xorshift sequence for reproducible mixed duplicate fuzzing. 
+ let mut s: u64 = 0xA5A5_5A5A_C3C3_3C3C; + let mut admitted = 0usize; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let dc = (s as i16).wrapping_sub(i16::MAX / 2); + if should_log_unknown_dc(dc) { + admitted += 1; + } + } + + assert!( + admitted <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "mixed-duplicate fuzzed inputs must not admit more than cap" + ); +} + +#[test] +fn scope_hint_accepts_ascii_alnum_and_dash_within_limit() { + assert_eq!(validated_scope_hint("scope_alpha-1"), Some("alpha-1")); + assert_eq!(validated_scope_hint("scope_AZ09"), Some("AZ09")); +} + +#[test] +fn scope_hint_rejects_invalid_or_oversized_values() { + assert_eq!(validated_scope_hint("plain_user"), None); + assert_eq!(validated_scope_hint("scope_"), None); + assert_eq!(validated_scope_hint("scope_a/b"), None); + assert_eq!(validated_scope_hint("scope_bad space"), None); + assert_eq!(validated_scope_hint("scope_bad.dot"), None); + + let oversized = format!("scope_{}", "a".repeat(MAX_SCOPE_HINT_LEN + 1)); + assert_eq!(validated_scope_hint(&oversized), None); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_parent_traversal_inputs() { + assert!( + sanitize_unknown_dc_log_path("../unknown-dc.txt").is_none(), + "parent traversal paths must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("logs/../unknown-dc.txt").is_none(), + "embedded parent traversal must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path("./../unknown-dc.txt").is_none(), + "relative parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_absolute_paths_with_existing_parent() { + let absolute = std::env::temp_dir().join("unknown-dc.txt"); + let absolute_str = absolute + .to_str() + .expect("temp absolute path must be valid UTF-8"); + + let sanitized = sanitize_unknown_dc_log_path(absolute_str) + .expect("absolute paths with existing parent must be accepted"); + assert_eq!(sanitized.resolved_path, absolute); +} + +#[test] +fn 
unknown_dc_log_path_sanitizer_rejects_absolute_parent_traversal() { + assert!( + sanitize_unknown_dc_log_path("/tmp/../etc/passwd").is_none(), + "absolute parent traversal must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_safe_relative_path() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-{}", std::process::id())); + fs::create_dir_all(&base).expect("temp test directory must be creatable"); + + let candidate = base.join("unknown-dc.txt"); + let candidate_relative = format!("target/telemt-unknown-dc-log-{}/unknown-dc.txt", std::process::id()); + + let sanitized = sanitize_unknown_dc_log_path(&candidate_relative) + .expect("safe relative path with existing parent must be accepted"); + assert_eq!(sanitized.resolved_path, candidate); +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_empty_or_dot_only_inputs() { + assert!( + sanitize_unknown_dc_log_path("").is_none(), + "empty path must be rejected" + ); + assert!( + sanitize_unknown_dc_log_path(".").is_none(), + "dot-only path without filename must be rejected" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_directory_only_as_filename_projection() { + let sanitized = sanitize_unknown_dc_log_path("target/") + .expect("directory-only input is interpreted as filename projection in current sanitizer"); + assert!( + sanitized.resolved_path.ends_with("target"), + "directory-only input should resolve to canonical parent plus filename projection" + ); +} + +#[test] +fn unknown_dc_log_path_sanitizer_accepts_dot_prefixed_relative_path() { + let rel_dir = format!("target/telemt-unknown-dc-dot-{}", std::process::id()); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("dot-prefixed test directory must be creatable"); + + let rel_candidate = format!("./{rel_dir}/unknown-dc.log"); + let expected = 
abs_dir.join("unknown-dc.log"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("dot-prefixed safe path must be accepted"); + assert_eq!(sanitized.resolved_path, expected); +} + +#[test] +fn light_fuzz_unknown_dc_path_parentdir_inputs_always_rejected() { + let mut s: u64 = 0xD00D_BAAD_1234_5678; + for _ in 0..4096 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let a = (s as usize) % 32; + let b = ((s >> 8) as usize) % 32; + let candidate = format!("target/{a}/../{b}/unknown-dc.log"); + assert!( + sanitize_unknown_dc_log_path(&candidate).is_none(), + "parent-dir candidate must be rejected: {candidate}" + ); + } +} + +#[test] +fn unknown_dc_log_path_sanitizer_rejects_nonexistent_parent_directory() { + let rel_candidate = format!( + "target/telemt-unknown-dc-missing-{}/nested/unknown-dc.txt", + std::process::id() + ); + + assert!( + sanitize_unknown_dc_log_path(&rel_candidate).is_none(), + "path with missing parent must be rejected to avoid implicit directory creation" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlinked_parent_inside_workspace() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-internal-{}", std::process::id())); + let real_parent = base.join("real_parent"); + fs::create_dir_all(&real_parent).expect("real parent dir must be creatable"); + + let symlink_parent = base.join("internal_link"); + let _ = fs::remove_file(&symlink_parent); + symlink(&real_parent, &symlink_parent).expect("internal symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-internal-{}/internal_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent that resolves inside workspace must be accepted"); + assert!( + sanitized.resolved_path.starts_with(&real_parent), + 
"sanitized path must resolve to canonical internal parent" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_sanitizer_accepts_symlink_parent_escape_as_canonical_path() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-log-symlink-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink test directory must be creatable"); + + let symlink_parent = base.join("escape_link"); + let _ = fs::remove_file(&symlink_parent); + symlink("/tmp", &symlink_parent).expect("symlink parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-log-symlink-{}/escape_link/unknown-dc.txt", + std::process::id() + ); + + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("symlinked parent must canonicalize to target path"); + assert!( + sanitized.resolved_path.starts_with(Path::new("/tmp")), + "sanitized path must resolve to canonical symlink target" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_symlinked_target_escape() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-target-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("target-link base must be creatable"); + + let outside = std::env::temp_dir().join(format!("telemt-outside-{}", std::process::id())); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("target symlink must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-target-link-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate should 
sanitize before final revalidation"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "final revalidation must reject symlinked target escape" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-nofollow-{}", std::process::id())); + fs::create_dir_all(&base).expect("nofollow base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-nofollow-outside-{}.log", + std::process::id() + )); + let _ = fs::remove_file(&outside); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink target must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_open_append_rejects_broken_symlink_target_with_nofollow() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-broken-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("broken-link base must be creatable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(base.join("missing-target.log"), &linked_target) + .expect("broken symlink target must be creatable"); + + let err = open_unknown_dc_log_append(&linked_target) + .expect_err("O_NOFOLLOW open must fail for broken symlink target"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "broken symlink target 
must be rejected with ELOOP when O_NOFOLLOW is applied" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_unknown_dc_open_append_symlink_flip_never_writes_outside_file() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-symlink-flip-{}", std::process::id())); + fs::create_dir_all(&base).expect("symlink-flip base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-symlink-flip-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside-baseline\n").expect("outside baseline file must be writable"); + let outside_before = fs::read_to_string(&outside).expect("outside baseline must be readable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + for step in 0..1024usize { + let _ = fs::remove_file(&target); + if step % 2 == 0 { + symlink(&outside, &target).expect("symlink creation in flip loop must succeed"); + } + if let Ok(mut file) = open_unknown_dc_log_append(&target) { + writeln!(file, "dc_idx={step}").expect("append on regular file must succeed"); + } + } + + let outside_after = fs::read_to_string(&outside).expect("outside file must remain readable"); + assert_eq!( + outside_after, outside_before, + "outside file must never be modified under symlink-flip adversarial churn" + ); +} + +#[test] +fn unknown_dc_open_append_creates_regular_file() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-{}", std::process::id())); + fs::create_dir_all(&base).expect("open test base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + { + let mut file = open_unknown_dc_log_append(&target) + .expect("regular target must be creatable with append open"); + writeln!(file, "dc_idx=1234").expect("append write must succeed"); + } + + let meta = 
fs::symlink_metadata(&target).expect("created target metadata must be readable"); + assert!(meta.file_type().is_file(), "target must be a regular file"); + assert!( + !meta.file_type().is_symlink(), + "regular target open path must not produce symlink artifacts" + ); +} + +#[test] +fn stress_unknown_dc_open_append_regular_file_preserves_line_integrity() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-open-stress-{}", std::process::id())); + fs::create_dir_all(&base).expect("stress open base must be creatable"); + + let target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&target); + + let writes = 2048usize; + for idx in 0..writes { + let mut file = open_unknown_dc_log_append(&target) + .expect("stress append open on regular file must succeed"); + writeln!(file, "dc_idx={idx}").expect("stress append write must succeed"); + } + + let content = fs::read_to_string(&target).expect("stress output file must be readable"); + assert_eq!( + nonempty_line_count(&content), + writes, + "regular-file append stress must preserve one logical line per write" + ); +} + +#[test] +fn unknown_dc_log_path_revalidation_accepts_regular_existing_target() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-safe-target-{}", std::process::id())); + fs::create_dir_all(&base).expect("safe target base must be creatable"); + + let target = base.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("safe target seed write must succeed"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-safe-target-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("safe candidate must sanitize"); + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must allow safe existing regular files" + ); +} + +#[test] +fn 
unknown_dc_log_path_revalidation_rejects_deleted_parent_after_sanitize() { + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-vanish-parent-{}", std::process::id())); + fs::create_dir_all(&base).expect("vanish-parent base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-vanish-parent-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent deletion"); + + fs::remove_dir_all(&base).expect("test parent directory must be removable"); + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when sanitized parent disappears before write" + ); +} + +#[cfg(unix)] +#[test] +fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-parent-swap-{}", std::process::id())); + fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + + let moved = parent.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); + symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); + + assert!( + !unknown_dc_log_path_is_still_safe(&sanitized), + "revalidation must fail when canonical parent is swapped to a symlinked target" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_check_then_symlink_flip_is_blocked_by_nofollow_open() { + use std::os::unix::fs::symlink; + + let parent = std::env::current_dir() + 
.expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-check-open-race-{}", std::process::id())); + fs::create_dir_all(&parent).expect("check-open-race parent must be creatable"); + + let target = parent.join("unknown-dc.log"); + fs::write(&target, "seed\n").expect("seed target file must be writable"); + let rel_candidate = format!( + "target/telemt-unknown-dc-check-open-race-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-check-open-race-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "outside\n").expect("outside file must be writable"); + fs::remove_file(&target).expect("target removal before flip must succeed"); + symlink(&outside, &target).expect("target symlink flip must be creatable"); + + let err = open_unknown_dc_log_append(&sanitized.resolved_path) + .expect_err("nofollow open must fail after symlink flip between check and open"); + assert_eq!( + err.raw_os_error(), + Some(libc::ELOOP), + "symlink flip in check/open window must be neutralized by O_NOFOLLOW" + ); +} + +#[cfg(unix)] +#[test] +fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { + use std::os::unix::fs::symlink; + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-parent-swap-openat-{}", std::process::id())); + fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); + + let rel_candidate = format!( + "target/telemt-unknown-dc-parent-swap-openat-{}/unknown-dc.log", + std::process::id() + ); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) + .expect("candidate must sanitize before parent swap"); + 
fs::write(&sanitized.resolved_path, "seed\n").expect("seed target file must be writable"); + + assert!( + unknown_dc_log_path_is_still_safe(&sanitized), + "precondition: target should initially pass revalidation" + ); + + let outside_parent = std::env::temp_dir().join(format!( + "telemt-unknown-dc-parent-swap-openat-outside-{}", + std::process::id() + )); + fs::create_dir_all(&outside_parent).expect("outside parent directory must be creatable"); + let outside_target = outside_parent.join("unknown-dc.log"); + let _ = fs::remove_file(&outside_target); + + let moved = base.with_extension("bak"); + let _ = fs::remove_dir_all(&moved); + fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); + symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); + + let err = open_unknown_dc_log_append_anchored(&sanitized) + .expect_err("anchored open must fail when parent is swapped to symlink"); + let raw = err.raw_os_error(); + assert!( + matches!(raw, Some(libc::ELOOP) | Some(libc::ENOTDIR) | Some(libc::ENOENT)), + "anchored open must fail closed on parent swap race, got raw_os_error={raw:?}" + ); + assert!( + !outside_target.exists(), + "anchored open must never create a log file in swapped outside parent" + ); +} + +#[tokio::test] +async fn unknown_dc_absolute_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_001; + let file_path = std::env::temp_dir().join(format!( + "telemt-unknown-dc-abs-{}-{}.log", + std::process::id(), + dc_idx + )); + let _ = fs::remove_file(&file_path); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some( + file_path + .to_str() + .expect("temp file path must be valid UTF-8") + .to_string(), + ); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback 
routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&file_path) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("absolute unknown-DC log path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "absolute unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_safe_relative_log_path_writes_one_entry() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: i16 = 31_002; + let rel_dir = format!("target/telemt-unknown-dc-int-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("integration test log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + + let mut content = None; + for _ in 0..20 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(15)).await; + } + + let text = content.expect("safe relative path must produce exactly one log write"); + assert!( + text.contains(&format!("dc_idx={dc_idx}")), + "unknown-DC integration log must contain requested dc_idx" + ); +} + +#[tokio::test] +async fn unknown_dc_same_index_burst_writes_only_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let dc_idx: 
i16 = 31_010; + let rel_dir = format!("target/telemt-unknown-dc-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("same-index log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for _ in 0..64 { + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut content = None; + for _ in 0..30 { + if let Ok(text) = fs::read_to_string(&abs_file) { + content = Some(text); + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let text = content.expect("same-index burst must produce at least one log write"); + assert_eq!( + nonempty_line_count(&text), + 1, + "same unknown dc index must be deduplicated to one file line" + ); +} + +#[tokio::test] +async fn unknown_dc_distinct_burst_is_hard_capped_on_file_writes() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-unknown-dc-cap-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir().unwrap().join(&rel_dir); + fs::create_dir_all(&abs_dir).expect("cap log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + for i in 0..(UNKNOWN_DC_LOG_DISTINCT_LIMIT + 128) { + let dc_idx = 20_000i16.wrapping_add(i as i16); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + } + + let mut 
final_text = String::new(); + for _ in 0..80 { + if let Ok(text) = fs::read_to_string(&abs_file) { + final_text = text; + if nonempty_line_count(&final_text) >= UNKNOWN_DC_LOG_DISTINCT_LIMIT { + break; + } + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + let line_count = nonempty_line_count(&final_text); + assert!( + line_count > 0, + "distinct unknown-dc burst must write at least one line" + ); + assert!( + line_count <= UNKNOWN_DC_LOG_DISTINCT_LIMIT, + "distinct unknown-dc writes must stay within dedup hard cap" + ); +} + +#[cfg(unix)] +#[tokio::test] +async fn unknown_dc_symlinked_target_escape_is_not_written_integration() { + use std::os::unix::fs::symlink; + + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let base = std::env::current_dir() + .expect("cwd must be available") + .join("target") + .join(format!("telemt-unknown-dc-no-write-link-{}", std::process::id())); + fs::create_dir_all(&base).expect("integration symlink base must be creatable"); + + let outside = std::env::temp_dir().join(format!( + "telemt-unknown-dc-outside-{}.log", + std::process::id() + )); + fs::write(&outside, "baseline\n").expect("outside baseline file must be writable"); + + let linked_target = base.join("unknown-dc.log"); + let _ = fs::remove_file(&linked_target); + symlink(&outside, &linked_target).expect("symlink target must be creatable"); + + let rel_file = format!( + "target/telemt-unknown-dc-no-write-link-{}/unknown-dc.log", + std::process::id() + ); + let dc_idx: i16 = 31_050; + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let before = fs::read_to_string(&outside).expect("must read baseline outside file"); + let _ = get_dc_addr_static(dc_idx, &cfg).expect("fallback routing must still work"); + tokio::time::sleep(Duration::from_millis(80)).await; + let after = 
fs::read_to_string(&outside).expect("must read outside file after attempt"); + + assert_eq!( + after, before, + "symlink target escape must not be written by unknown-DC logging" + ); +} + +#[test] +fn fallback_dc_never_panics_with_single_dc_list() { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.default_dc = Some(42); + + let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely"); + let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT); + assert_eq!(addr, expected); +} + +#[tokio::test] +async fn direct_relay_abort_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: 
"abort-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50000".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xabad1dea, + )); + + let started = tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "direct relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted direct relay task must return join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "route gauge must be released when direct relay task is aborted mid-flight" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_midflight_releases_route_gauge() { + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let (stream, _) = tg_listener.accept().await.unwrap(); + let _hold_stream = stream; + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + 
selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-direct-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50002".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats.clone(), + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_snapshot, + 0xface_cafe, + )); + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("direct relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Middle).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(Duration::from_secs(6), relay_task) + .await + .expect("direct relay must terminate after cutover") + .expect("direct relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate direct relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_direct(), + 0, + 
"route gauge must be released when direct relay exits on cutover" + ); + + drop(client_side); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[tokio::test] +async fn direct_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let tg_addr = tg_listener.local_addr().unwrap(); + + let tg_accept_task = tokio::spawn(async move { + let mut held_streams = Vec::with_capacity(session_count); + for _ in 0..session_count { + let (stream, _) = tg_listener.accept().await.unwrap(); + held_streams.push(stream); + } + tokio::time::sleep(Duration::from_secs(60)).await; + drop(held_streams); + }); + + let stats = Arc::new(Stats::new()); + let mut config = ProxyConfig::default(); + config + .dc_overrides + .insert("2".to_string(), vec![tg_addr.to_string()]); + let config = Arc::new(config); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + let rng = Arc::new(SecureRandom::new()); + let buffer_pool = Arc::new(BufferPool::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let client_reader = make_crypto_reader(server_reader); + let client_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-direct-user-{idx}"), + dc_idx: 2, 
+ proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 51000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xA000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(Duration::from_secs(4), async { + loop { + if stats.get_current_connections_direct() == session_count as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all direct sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Middle + } else { + RelayRouteMode::Direct + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(Duration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(Duration::from_secs(10), relay_task) + .await + .expect("direct relay task must finish under cutover storm") + .expect("direct relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all direct sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after cutover storm" + ); + + drop(client_sides); + tg_accept_task.abort(); + let _ = tg_accept_task.await; +} + +#[test] +fn prefer_v6_override_matrix_prefers_matching_family_then_degrades_safely() { + let dc_idx: i16 = 2; + + let mut cfg_a = 
ProxyConfig::default(); + cfg_a.network.prefer = 6; + cfg_a.network.ipv6 = Some(true); + cfg_a.dc_overrides.insert( + dc_idx.to_string(), + vec![ + "203.0.113.90:443".to_string(), + "[2001:db8::90]:443".to_string(), + ], + ); + let a = get_dc_addr_static(dc_idx, &cfg_a).expect("v6+v4 override set must resolve"); + assert!(a.is_ipv6(), "prefer_v6 should choose v6 override when present"); + + let mut cfg_b = ProxyConfig::default(); + cfg_b.network.prefer = 6; + cfg_b.network.ipv6 = Some(true); + cfg_b.dc_overrides + .insert(dc_idx.to_string(), vec!["203.0.113.91:443".to_string()]); + let b = get_dc_addr_static(dc_idx, &cfg_b).expect("v4-only override must still resolve"); + assert!(b.is_ipv4(), "when no v6 override exists, v4 override must be used"); + + let mut cfg_c = ProxyConfig::default(); + cfg_c.network.prefer = 6; + cfg_c.network.ipv6 = Some(true); + let c = get_dc_addr_static(dc_idx, &cfg_c).expect("table fallback must resolve"); + assert_eq!( + c, + SocketAddr::new(TG_DATACENTERS_V6[(dc_idx as usize) - 1], TG_DATACENTER_PORT), + "without overrides, prefer_v6 path must resolve from static v6 datacenter table" + ); +} + +#[test] +fn prefer_v6_override_matrix_ignores_invalid_entries_and_keeps_fail_closed_fallback() { + let dc_idx: i16 = 3; + + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides.insert( + dc_idx.to_string(), + vec![ + "not-an-addr".to_string(), + "also:bad".to_string(), + "203.0.113.55:443".to_string(), + ], + ); + + let addr = get_dc_addr_static(dc_idx, &cfg).expect("at least one valid override must keep resolution alive"); + assert_eq!(addr, "203.0.113.55:443".parse::<SocketAddr>().unwrap()); +} + +#[test] +fn stress_prefer_v6_override_matrix_is_deterministic_under_mixed_inputs() { + for idx in 1..=5i16 { + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = 6; + cfg.network.ipv6 = Some(true); + cfg.dc_overrides.insert( + idx.to_string(), + vec![ + format!("203.0.113.{}:443", 100 + 
idx), + format!("[2001:db8::{}]:443", 100 + idx), + ], + ); + + let first = get_dc_addr_static(idx, &cfg).expect("first lookup must resolve"); + let second = get_dc_addr_static(idx, &cfg).expect("second lookup must resolve"); + assert_eq!(first, second, "override resolution must stay deterministic for dc {idx}"); + assert!(first.is_ipv6(), "dc {idx}: v6 override should be preferred"); + } +} + +#[tokio::test] +async fn negative_direct_relay_dc_connection_refused_fails_fast() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Reserve an ephemeral port and immediately release it to deterministically + // exercise the direct-connect failure path without long-lived hangs. 
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + drop(listener); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let result = timeout( + TokioDuration::from_secs(2), + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xABCD_1234, + ), + ) + .await + .expect("direct relay must fail fast on connection-refused upstream"); + + assert!( + result.is_err(), + "connection-refused upstream must fail closed" + ); +} + +#[tokio::test] +async fn adversarial_direct_relay_cutover_integrity() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_client_reader_relay, client_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let client_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let client_writer = CryptoWriter::new(client_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + // Mock upstream 
server. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let dc_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + // Read handshake nonce. + let mut nonce = [0u8; 64]; + let _ = stream.read_exact(&mut nonce).await; + // Keep connection open. + tokio::time::sleep(TokioDuration::from_secs(5)).await; + }); + + let mut config_with_override = ProxyConfig::default(); + config_with_override.dc_overrides.insert("1".to_string(), vec![dc_addr.to_string()]); + let config = Arc::new(config_with_override); + + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + enabled: true, + weight: 1, + scopes: String::new(), + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + selected_scope: String::new(), + }], + 1, + 100, + 5000, + 3, + false, + stats.clone(), + )); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let stats_for_task = stats.clone(); + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(async move { + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats_for_task, + config, + buffer_pool, + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xABCD_1234, + ).await + }); + + timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("direct relay session must start before cutover"); + + // Trigger cutover. + route_runtime.set_mode(RelayRouteMode::Middle).unwrap(); + + // The session should terminate after the staggered delay (1000-2000ms). 
+ let result = timeout(TokioDuration::from_secs(5), session_task) + .await + .expect("Session must terminate after cutover") + .expect("Session must not panic"); + + assert!( + matches!( + result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "Session must terminate with route switch error on cutover" + ); +} diff --git a/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs new file mode 100644 index 0000000..5cbbc68 --- /dev/null +++ b/src/proxy/tests/direct_relay_subtle_adversarial_tests.rs @@ -0,0 +1,197 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +fn nonempty_line_count(text: &str) -> usize { + text.lines().filter(|line| !line.trim().is_empty()).count() +} + +#[test] +fn subtle_stress_single_unknown_dc_under_concurrency_logs_once() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let winners = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + + for _ in 0..128 { + let winners = Arc::clone(&winners); + workers.push(std::thread::spawn(move || { + if should_log_unknown_dc(31_333) { + winners.fetch_add(1, Ordering::Relaxed); + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(winners.load(Ordering::Relaxed), 1); +} + +#[test] +fn subtle_light_fuzz_scope_hint_matches_oracle() { + fn oracle(input: &str) -> bool { + let Some(rest) = input.strip_prefix("scope_") else { + return false; + }; + !rest.is_empty() + && rest.len() <= MAX_SCOPE_HINT_LEN + && rest + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + } + + let mut state: u64 = 0xC0FF_EE11_D15C_AFE5; + for _ in 0..4_096 { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + + let len = (state as usize % 72) + 1; + let mut s = String::with_capacity(len + 6); + if (state & 1) == 
0 { + s.push_str("scope_"); + } else { + s.push_str("user_"); + } + + for idx in 0..len { + let v = ((state >> ((idx % 8) * 8)) & 0xff) as u8; + let ch = match v % 6 { + 0 => (b'a' + (v % 26)) as char, + 1 => (b'A' + (v % 26)) as char, + 2 => (b'0' + (v % 10)) as char, + 3 => '-', + 4 => '_', + _ => '.', + }; + s.push(ch); + } + + let got = validated_scope_hint(&s).is_some(); + assert_eq!(got, oracle(&s), "mismatch for input: {s}"); + } +} + +#[test] +fn subtle_light_fuzz_dc_resolution_never_panics_and_preserves_port() { + let mut state: u64 = 0x1234_5678_9ABC_DEF0; + + for _ in 0..2_048 { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + + let mut cfg = ProxyConfig::default(); + cfg.network.prefer = if (state & 1) == 0 { 4 } else { 6 }; + cfg.network.ipv6 = Some((state & 2) != 0); + cfg.default_dc = Some(((state >> 8) as u8).max(1)); + + let dc_idx = (state as i16).wrapping_sub(16_384); + let resolved = get_dc_addr_static(dc_idx, &cfg).expect("dc resolution must never fail"); + + assert_eq!(resolved.port(), crate::protocol::constants::TG_DATACENTER_PORT); + let expect_v6 = cfg.network.prefer == 6 && cfg.network.ipv6.unwrap_or(true); + assert_eq!(resolved.is_ipv6(), expect_v6); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_same_dc_logs_one_line() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-same-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + 
cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let mut tasks = Vec::new(); + for _ in 0..32 { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(31_777, cfg.as_ref()); + })); + } + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..60 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) == 1 + { + return; + } + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert_eq!(nonempty_line_count(&content), 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn subtle_integration_parallel_unique_dcs_log_unique_lines() { + let _guard = unknown_dc_test_lock() + .lock() + .expect("unknown dc test lock must be available"); + clear_unknown_dc_log_cache_for_testing(); + + let rel_dir = format!("target/telemt-direct-relay-unique-{}", std::process::id()); + let rel_file = format!("{rel_dir}/unknown-dc.log"); + let abs_dir = std::env::current_dir() + .expect("cwd must be available") + .join(&rel_dir); + std::fs::create_dir_all(&abs_dir).expect("log directory must be creatable"); + let abs_file = abs_dir.join("unknown-dc.log"); + let _ = std::fs::remove_file(&abs_file); + + let mut cfg = ProxyConfig::default(); + cfg.general.unknown_dc_file_log_enabled = true; + cfg.general.unknown_dc_log_path = Some(rel_file); + + let cfg = Arc::new(cfg); + let dcs = [31_901_i16, 31_902, 31_903, 31_904, 31_905, 31_906, 31_907, 31_908]; + let mut tasks = Vec::new(); + + for dc in dcs { + let cfg = Arc::clone(&cfg); + tasks.push(tokio::spawn(async move { + let _ = get_dc_addr_static(dc, cfg.as_ref()); + })); + } + + for task in tasks { + task.await.expect("task must not panic"); + } + + for _ in 0..80 { + if let Ok(content) = std::fs::read_to_string(&abs_file) + && nonempty_line_count(&content) >= 8 + { + return; + } + 
tokio::time::sleep(std::time::Duration::from_millis(20)).await; + } + + let content = std::fs::read_to_string(&abs_file).unwrap_or_default(); + assert!( + nonempty_line_count(&content) >= 8, + "expected at least one line per unique dc, content: {content}" + ); +} diff --git a/src/proxy/tests/handshake_adversarial_tests.rs b/src/proxy/tests/handshake_adversarial_tests.rs new file mode 100644 index 0000000..da93ef4 --- /dev/null +++ b/src/proxy/tests/handshake_adversarial_tests.rs @@ -0,0 +1,467 @@ +use super::*; +use std::sync::Arc; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; +use crate::crypto::sha256; + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn auth_probe_test_guard() -> 
std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +// ------------------------------------------------------------------ +// Mutational Bit-Flipping Tests (OWASP ASVS 5.1.4) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_bit_flip_anywhere_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + // Baseline check + let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + match res { + HandshakeResult::Success(_) => {}, + _ => panic!("Baseline failed: expected Success"), + } + + // Flip bits in the encrypted part (beyond the key material) + for byte_pos in SKIP_LEN..HANDSHAKE_LEN { + let mut h = base; + h[byte_pos] ^= 0x01; // Flip 1 bit + let res = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + assert!(matches!(res, HandshakeResult::BadClient { .. 
}), "Flip at byte {byte_pos} bit 0 must be rejected"); + } +} + +// ------------------------------------------------------------------ +// Adversarial Probing / Timing Neutrality (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn mtproto_handshake_timing_neutrality_mocked() { + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + const ITER: usize = 50; + + let mut start = Instant::now(); + for _ in 0..ITER { + let _ = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + } + let duration_success = start.elapsed(); + + start = Instant::now(); + for i in 0..ITER { + let mut h = base; + h[SKIP_LEN + (i % 48)] ^= 0xFF; + let _ = handle_mtproto_handshake(&h, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + } + let duration_fail = start.elapsed(); + + let avg_diff_ms = (duration_success.as_millis() as f64 - duration_fail.as_millis() as f64).abs() / ITER as f64; + + // Threshold (loose for CI) + assert!(avg_diff_ms < 100.0, "Timing difference too large: {} ms/iter", avg_diff_ms); +} + +// ------------------------------------------------------------------ +// Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn auth_probe_throttle_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + + // Record enough failures for one IP to trigger backoff + let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + 
auth_probe_record_failure(target_ip, now); + } + + assert!(auth_probe_is_throttled(target_ip, now)); + + // Stress test with many unique IPs + for i in 0..500u32 { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8)); + auth_probe_record_failure(ip, now); + } + + let tracked = AUTH_PROBE_STATE + .get() + .map(|state| state.len()) + .unwrap_or(0); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}" + ); +} + +#[tokio::test] +async fn mtproto_handshake_abridged_prefix_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + handshake[0] = 0xef; // Abridged prefix + let config = ProxyConfig::default(); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let res = handle_mtproto_handshake(&handshake, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + // MTProxy stops immediately on 0xef + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn mtproto_handshake_preferred_user_mismatch_continues() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret1_hex = "11111111111111111111111111111111"; + let secret2_hex = "22222222222222222222222222222222"; + + let base = make_valid_mtproto_handshake(secret2_hex, ProtoTag::Secure, 1); + let mut config = ProxyConfig::default(); + config.access.users.insert("user1".to_string(), secret1_hex.to_string()); + config.access.users.insert("user2".to_string(), secret2_hex.to_string()); + config.access.ignore_time_skew = true; + config.general.modes.secure = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.4:12345".parse().unwrap(); + + // Even if we prefer user1, if user2 matches, it should succeed. + let res = handle_mtproto_handshake(&base, tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, Some("user1")).await; + if let HandshakeResult::Success((_, _, success)) = res { + assert_eq!(success.user, "user2"); + } else { + panic!("Handshake failed even though user2 matched"); + } +} + +#[tokio::test] +async fn mtproto_handshake_concurrent_flood_stability() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut config = test_config_with_secret_hex(secret_hex); + config.access.ignore_time_skew = true; + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let config = Arc::new(config); + + let mut tasks = Vec::new(); + for i in 0..50 { + let base = base; + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let peer: SocketAddr = format!("192.0.2.{}:12345", (i % 254) + 1).parse().unwrap(); + + tasks.push(tokio::spawn(async move { + let res = handle_mtproto_handshake(&base, 
tokio::io::empty(), tokio::io::sink(), peer, &config, &replay_checker, false, None).await; + matches!(res, HandshakeResult::Success(_)) + })); + } + + // We don't necessarily care if they all succeed (some might fail due to replay if they hit the same chunk), + // but the system must not panic or hang. + for task in tasks { + let _ = task.await.unwrap(); + } +} + +#[tokio::test] +async fn mtproto_replay_is_rejected_across_distinct_peers() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "0123456789abcdeffedcba9876543210"; + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + + let first_peer: SocketAddr = "198.51.100.10:41001".parse().unwrap(); + let second_peer: SocketAddr = "198.51.100.11:41002".parse().unwrap(); + + let first = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + second_peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(replay, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "89abcdef012345670123456789abcdef"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(8192, Duration::from_secs(60)); + + for i in 0..512usize { + let mut mutated = base; + let pos = (SKIP_LEN + (i * 31) % (HANDSHAKE_LEN - SKIP_LEN)).min(HANDSHAKE_LEN - 1); + mutated[pos] ^= ((i as u8) | 1).rotate_left((i % 8) as u32); + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, (i / 254) as u8, (i % 254 + 1) as u8)), + 42000 + (i % 1000) as u16, + ); + + let res = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed mutation must complete in bounded time"); + + assert!( + matches!(res, HandshakeResult::BadClient { .. 
} | HandshakeResult::Success(_)), + "mutation corpus must stay within explicit handshake outcomes" + ); + } +} + +#[tokio::test] +async fn auth_probe_success_clears_throttled_peer_state() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let target_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 90)); + let now = Instant::now(); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(target_ip, now); + } + assert!(auth_probe_is_throttled(target_ip, now)); + + auth_probe_record_success(target_ip); + assert!( + !auth_probe_is_throttled(target_ip, now + Duration::from_millis(1)), + "successful auth must clear per-peer throttle state" + ); +} + +#[tokio::test] +async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let mut invalid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + invalid[SKIP_LEN + 3] ^= 0xff; + + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(64, Duration::from_secs(60)); + + for i in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 512) { + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, (i / 65535) as u8, ((i / 255) % 255) as u8, (i % 255 + 1) as u8)), + 43000 + (i % 20000) as u16, + ); + let res = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. 
})); + } + + let tracked = AUTH_PROBE_STATE + .get() + .map(|state| state.len()) + .unwrap_or(0); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "probe map must remain bounded under invalid storm: {tracked}" + ); +} + +#[tokio::test] +async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "f0e1d2c3b4a5968778695a4b3c2d1e0f"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(10_000, Duration::from_secs(60)); + + let mut seed: u64 = 0xC0FF_EE12_3456_789A; + for i in 0..2_048usize { + let mut mutated = base; + for _ in 0..4 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN)); + mutated[idx] ^= ((seed >> 11) as u8).wrapping_add(1); + } + + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 123, (i / 254) as u8, (i % 254 + 1) as u8)), + 45000 + (i % 2000) as u16, + ); + + let outcome = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("mutation iteration must complete in bounded time"); + + assert!( + matches!(outcome, HandshakeResult::BadClient { .. 
} | HandshakeResult::Success(_)), + "mutations must remain fail-closed/auth-only" + ); + } +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn mtproto_blackhat_20k_mutation_soak_never_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(50_000, Duration::from_secs(120)); + + let mut seed: u64 = 0xA5A5_5A5A_DEAD_BEEF; + for i in 0..20_000usize { + let mut mutated = base; + for _ in 0..3 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let idx = SKIP_LEN + (seed as usize % (HANDSHAKE_LEN - SKIP_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + + let peer: SocketAddr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 31, (i / 254) as u8, (i % 254 + 1) as u8)), + 47000 + (i % 15000) as u16, + ); + + let _ = tokio::time::timeout( + Duration::from_millis(250), + handle_mtproto_handshake( + &mutated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("soak mutation must complete in bounded time"); + } +} diff --git a/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs new file mode 100644 index 0000000..d8fac4f --- /dev/null +++ b/src/proxy/tests/handshake_auth_probe_hardening_adversarial_tests.rs @@ -0,0 +1,187 @@ +use super::*; +use std::net::{IpAddr, Ipv4Addr}; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[test] +fn positive_preauth_throttle_activates_after_failure_threshold() { + let _guard = auth_probe_test_guard(); + 
clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 20)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + + assert!( + auth_probe_is_throttled(ip, now), + "peer must be throttled once fail streak reaches threshold" + ); +} + +#[test] +fn negative_unrelated_peer_remains_unthrottled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let attacker = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 12)); + let benign = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 13)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(attacker, now); + } + + assert!(auth_probe_is_throttled(attacker, now)); + assert!( + !auth_probe_is_throttled(benign, now), + "throttle state must stay scoped to normalized peer key" + ); +} + +#[test] +fn edge_expired_entry_is_pruned_and_no_longer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 41)); + let base = Instant::now(); + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, base); + } + + let expired_at = base + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + assert!( + !auth_probe_is_throttled(ip, expired_at), + "expired entries must not keep throttling peers" + ); + + let state = auth_probe_state_map(); + assert!( + state.get(&normalize_auth_probe_ip(ip)).is_none(), + "expired lookup should prune stale state" + ); +} + +#[test] +fn adversarial_saturation_grace_requires_extra_failures_before_preauth_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip = IpAddr::V4(Ipv4Addr::new(198, 18, 0, 7)); + let now = Instant::now(); + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + auth_probe_record_failure(ip, now); + } + auth_probe_note_saturation(now); + + assert!( + 
!auth_probe_should_apply_preauth_throttle(ip, now), + "during global saturation, peer must receive configured grace window" + ); + + for _ in 0..AUTH_PROBE_SATURATION_GRACE_FAILS { + auth_probe_record_failure(ip, now + Duration::from_millis(1)); + } + + assert!( + auth_probe_should_apply_preauth_throttle(ip, now + Duration::from_millis(1)), + "after grace failures are exhausted, preauth throttle must activate" + ); +} + +#[test] +fn integration_over_cap_insertion_keeps_probe_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 1024) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + ((idx / 65_536) % 256) as u8, + ((idx / 256) % 256) as u8, + (idx % 256) as u8, + )); + auth_probe_record_failure(ip, now); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "probe map must remain hard bounded under insertion storm" + ); +} + +#[test] +fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut seed = 0x4D53_5854_6F66_6175u64; + let now = Instant::now(); + + for _ in 0..8192 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let ip = IpAddr::V4(Ipv4Addr::new( + (seed >> 24) as u8, + (seed >> 16) as u8, + (seed >> 8) as u8, + seed as u8, + )); + auth_probe_record_failure(ip, now + Duration::from_millis((seed & 0x3f) as u64)); + } + + let state = auth_probe_state_map(); + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES); + for entry in state.iter() { + assert!(entry.value().fail_streak > 0); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_failure_flood_keeps_state_hard_capped() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let start = Instant::now(); + let mut tasks = Vec::new(); + + for 
worker in 0..8u8 { + tasks.push(tokio::spawn(async move { + for i in 0..4096u32 { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + worker, + ((i >> 8) & 0xff) as u8, + (i & 0xff) as u8, + )); + auth_probe_record_failure(ip, start + Duration::from_millis((i % 4) as u64)); + } + })); + } + + for task in tasks { + task.await.expect("stress worker must not panic"); + } + + let tracked = auth_probe_state_map().len(); + assert!( + tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "parallel failure flood must not exceed cap" + ); + + let probe = IpAddr::V4(Ipv4Addr::new(172, 3, 4, 5)); + let _ = auth_probe_is_throttled(probe, start + Duration::from_millis(2)); +} diff --git a/src/proxy/tests/handshake_fuzz_security_tests.rs b/src/proxy/tests/handshake_fuzz_security_tests.rs new file mode 100644 index 0000000..d72c9cd --- /dev/null +++ b/src/proxy/tests/handshake_fuzz_security_tests.rs @@ -0,0 +1,270 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::crypto::AesCtr; +use crate::crypto::sha256; +use crate::protocol::constants::ProtoTag; +use crate::stats::ReplayChecker; +use std::net::SocketAddr; +use std::sync::MutexGuard; +use tokio::time::{timeout, Duration as TokioDuration}; + +fn make_mtproto_handshake_with_proto_bytes( + secret_hex: &str, + proto_bytes: [u8; 4], + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + 
dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_bytes); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + make_mtproto_handshake_with_proto_bytes(secret_hex, proto_tag.to_bytes(), dc_idx) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access.users.insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg +} + +fn auth_probe_test_guard() -> MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +#[tokio::test] +async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "11223344556677889900aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let first = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, 
+ &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "00112233445566778899aabbccddeeff"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.2:54321".parse().unwrap(); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0x00, 0x00, 0x00, 0x00], + 1, + )); + corpus.push(make_mtproto_handshake_with_proto_bytes( + secret_hex, + [0xff, 0xff, 0xff, 0xff], + 1, + )); + corpus.push(make_valid_mtproto_handshake( + "ffeeddccbbaa99887766554433221100", + ProtoTag::Secure, + 1, + )); + + let mut seed = 0xF0F0_F00D_BAAD_CAFEu64; + for _ in 0..32 { + let mut mutated = base; + for _ in 0..4 { + seed = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 19) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.into_iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. 
}), + "corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} + +#[tokio::test] +async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "99887766554433221100ffeeddccbbaa"; + let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4); + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(256, TokioDuration::from_secs(60)); + let peer: SocketAddr = "192.0.2.44:45444".parse().unwrap(); + + let first = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("base handshake must not hang"); + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + &base, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("duplicate handshake must not hang"); + assert!(matches!(replay, HandshakeResult::BadClient { .. 
})); + + let mut corpus = Vec::<[u8; HANDSHAKE_LEN]>::new(); + + let mut prekey_flip = base; + prekey_flip[SKIP_LEN] ^= 0x80; + corpus.push(prekey_flip); + + let mut iv_flip = base; + iv_flip[SKIP_LEN + PREKEY_LEN] ^= 0x01; + corpus.push(iv_flip); + + let mut tail_flip = base; + tail_flip[SKIP_LEN + PREKEY_LEN + IV_LEN - 1] ^= 0x40; + corpus.push(tail_flip); + + let mut seed = 0xBADC_0FFE_EE11_4242u64; + for _ in 0..24 { + let mut mutated = base; + for _ in 0..3 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let idx = SKIP_LEN + (seed as usize % (PREKEY_LEN + IV_LEN)); + mutated[idx] ^= ((seed >> 16) as u8).wrapping_add(1); + } + corpus.push(mutated); + } + + for (idx, input) in corpus.iter().enumerate() { + let result = timeout( + TokioDuration::from_secs(1), + handle_mtproto_handshake( + input, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ), + ) + .await + .expect("fuzzed handshake must complete in time"); + + assert!( + matches!(result, HandshakeResult::BadClient { .. 
}), + "mixed corpus item {idx} must fail closed" + ); + } + + clear_auth_probe_state_for_testing(); +} \ No newline at end of file diff --git a/src/proxy/tests/handshake_saturation_poison_security_tests.rs b/src/proxy/tests/handshake_saturation_poison_security_tests.rs new file mode 100644 index 0000000..4c2ca5d --- /dev/null +++ b/src/proxy/tests/handshake_saturation_poison_security_tests.rs @@ -0,0 +1,71 @@ +use super::*; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn poison_saturation_mutex() { + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _guard = saturation + .lock() + .expect("saturation mutex must be lockable for poison setup"); + panic!("intentional poison for saturation mutex resilience test"); + }); + let _ = poison_thread.join(); +} + +#[test] +fn auth_probe_saturation_note_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + let now = Instant::now(); + auth_probe_note_saturation(now); + + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now), + "poisoned saturation mutex must not disable saturation throttling" + ); +} + +#[test] +fn auth_probe_saturation_check_recovers_after_mutex_poison() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_millis(10), + last_seen: Instant::now(), + }); + } + + assert!( + auth_probe_saturation_is_throttled_for_testing(), + "throttle check must recover poisoned saturation mutex and stay fail-closed" + ); +} + +#[test] +fn 
clear_auth_probe_state_clears_saturation_even_if_poisoned() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + poison_saturation_mutex(); + + auth_probe_note_saturation(Instant::now()); + assert!(auth_probe_saturation_is_throttled_for_testing()); + + clear_auth_probe_state_for_testing(); + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "clear helper must clear saturation state even after poison" + ); +} diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs new file mode 100644 index 0000000..b646d1f --- /dev/null +++ b/src/proxy/tests/handshake_security_tests.rs @@ -0,0 +1,3444 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac}; +use dashmap::DashMap; +use rand::{RngExt, SeedableRng}; +use rand::rngs::StdRng; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + 
body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + 
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); 
+ cfg.access.ignore_time_skew = true; + cfg +} + +fn make_valid_mtproto_handshake(secret_hex: &str, proto_tag: ProtoTag, dc_idx: i16) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode for mtproto test helper"); + + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[test] +fn test_generate_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert_eq!(nonce.len(), HANDSHAKE_LEN); + + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); + assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); +} + +#[test] +fn test_encrypt_tg_nonce() { + let client_enc_key = [0x24u8; 32]; + let 
#[test]
fn test_handshake_success_drop_does_not_panic() {
    // Construct a fully-populated HandshakeSuccess and drop it; this guards
    // against panics in any Drop impl (e.g. key zeroization).
    let success = HandshakeSuccess {
        user: "test".to_string(),
        dc_idx: 2,
        proto_tag: ProtoTag::Secure,
        dec_key: [0xAA; 32],
        dec_iv: 0xBBBBBBBB,
        enc_key: [0xCC; 32],
        enc_iv: 0xDDDDDDDD,
        peer: "198.51.100.10:1234".parse().unwrap(),
        is_tls: true,
    };

    assert_eq!(success.dec_key, [0xAA; 32]);
    assert_eq!(success.enc_key, [0xCC; 32]);

    drop(success);
}

#[test]
fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
    let client_enc_key = [0x34u8; 32];
    let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128;
    let rng = SecureRandom::new();

    let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
        ProtoTag::Secure,
        7,
        &client_enc_key,
        client_enc_iv,
        &rng,
        false,
    );

    let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
    // FIX: `Vec` was missing its element type and would not compile; the
    // decrypt material is defined as the byte-reversed encrypt material.
    let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();

    let mut expected_tg_enc_key = [0u8; 32];
    expected_tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
    let mut expected_tg_enc_iv_arr = [0u8; IV_LEN];
    expected_tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
    let expected_tg_enc_iv = u128::from_be_bytes(expected_tg_enc_iv_arr);

    let mut expected_tg_dec_key = [0u8; 32];
    expected_tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
    let mut expected_tg_dec_iv_arr = [0u8; IV_LEN];
    expected_tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
    let expected_tg_dec_iv = u128::from_be_bytes(expected_tg_dec_iv_arr);

    assert_eq!(tg_enc_key, expected_tg_enc_key);
    assert_eq!(tg_enc_iv, expected_tg_enc_iv);
    assert_eq!(tg_dec_key, expected_tg_dec_key);
    assert_eq!(tg_dec_iv, expected_tg_dec_iv);
    assert_eq!(
        i16::from_le_bytes([nonce[DC_IDX_POS], nonce[DC_IDX_POS + 1]]),
        7,
        "Generated nonce must keep target dc index in protocol slot"
    );
}

#[test]
fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
    let client_enc_key = [0xABu8; 32];
    let client_enc_iv = 0x11223344556677889900aabbccddeeffu128;
    let rng = SecureRandom::new();

    // fast=true: nonce must embed the client's enc key||iv, byte-reversed.
    let (nonce, _, _, _, _) = generate_tg_nonce(
        ProtoTag::Secure,
        9,
        &client_enc_key,
        client_enc_iv,
        &rng,
        true,
    );

    let mut expected = Vec::with_capacity(KEY_LEN + IV_LEN);
    expected.extend_from_slice(&client_enc_key);
    expected.extend_from_slice(&client_enc_iv.to_be_bytes());
    expected.reverse();

    assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice());
}

#[test]
fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() {
    let client_enc_key = [0x24u8; 32];
    let client_enc_iv = 54321u128;

    let rng = SecureRandom::new();
    let (nonce, _, _, _, _) = generate_tg_nonce(
        ProtoTag::Secure,
        2,
        &client_enc_key,
        client_enc_iv,
        &rng,
        false,
    );

    let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(&nonce);

    // Re-derive enc key/iv from the nonce body and encrypt manually.
    let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
    let mut expected_enc_key = [0u8; 32];
    expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
    let mut expected_enc_iv_arr = [0u8; IV_LEN];
    expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
    let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr);

    let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv);
    let manual = manual_encryptor.encrypt(&nonce);

    assert_eq!(encrypted.len(), HANDSHAKE_LEN);
    assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
    assert_eq!(
        &encrypted[PROTO_TAG_POS..],
        &manual[PROTO_TAG_POS..],
        "Encrypted nonce suffix must match AES-CTR output with derived enc key/iv"
    );
}
"Encrypted nonce suffix must match AES-CTR output with derived enc key/iv" + ); +} + +#[tokio::test] +async fn tls_replay_second_identical_handshake_is_rejected() { + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.21:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let second = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(second, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn tls_replay_with_ignore_time_skew_and_small_boot_timestamp_is_still_blocked() { + let secret = [0x19u8; 16]; + let config = test_config_with_secret_hex("19191919191919191919191919191919"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.121:44321".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 1); + + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. 
}), + "ignore_time_skew must not weaken replay rejection for small boot timestamps" + ); +} + +#[tokio::test] +async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { + let secret = [0x77u8; 16]; + let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let handshake = handshake.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.22:45000".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut success_count = 0usize; + for task in tasks { + let result = task.await.unwrap(); + if matches!(result, HandshakeResult::Success(_)) { + success_count += 1; + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + } + } + + assert_eq!( + success_count, 1, + "Concurrent replay attempts must allow exactly one successful handshake" + ); +} + +#[tokio::test] +async fn tls_replay_matrix_rotating_peers_first_accept_then_rejects() { + let secret = [0x52u8; 16]; + let config = test_config_with_secret_hex("52525252525252525252525252525252"); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let handshake = make_valid_tls_handshake(&secret, 17); + + let first_peer: SocketAddr = "198.51.100.31:44001".parse().unwrap(); + let first = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + first_peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(first, HandshakeResult::Success(_))); + + for i in 0..128u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 51, 100, ((i % 250) + 1) as u8)), + 45000 + i, + ); + let replay = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(replay, HandshakeResult::BadClient { .. 
}), + "replay digest must be rejected regardless of source peer rotation" + ); + } +} + +#[tokio::test] +async fn adversarial_tls_replay_churn_allows_only_unique_digests() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.access.ignore_time_skew = true; + let config = Arc::new(config); + let replay_checker = Arc::new(ReplayChecker::new(8192, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + + let make_tagged_handshake = |timestamp: u32, tag: u8| { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![tag; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(&secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake + }; + + let mut tasks = Vec::new(); + + // 128 exact duplicates: only one should pass. + let duplicated = Arc::new(make_valid_tls_handshake(&secret, 999)); + for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let duplicated = Arc::clone(&duplicated); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, ((i % 250) + 1) as u8)), + 46000 + i, + ); + handle_tls_handshake( + &duplicated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + // 128 unique timestamps: all should pass because HMAC digest differs. 
+ for i in 0..128u16 { + let config = Arc::clone(&config); + let replay_checker = Arc::clone(&replay_checker); + let rng = Arc::clone(&rng); + let handshake = make_tagged_handshake(10_000 + i as u32, (i as u8).wrapping_add(0x80)); + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, 0, ((i % 250) + 1) as u8)), + 47000 + i, + ); + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut duplicate_success = 0usize; + let mut duplicate_reject = 0usize; + let mut unique_success = 0usize; + let mut unique_reject = 0usize; + + for (idx, task) in tasks.into_iter().enumerate() { + let result = task.await.unwrap(); + let is_duplicate_group = idx < 128; + match result { + HandshakeResult::Success(_) => { + if is_duplicate_group { + duplicate_success += 1; + } else { + unique_success += 1; + } + } + HandshakeResult::BadClient { .. } => { + if is_duplicate_group { + duplicate_reject += 1; + } else { + unique_reject += 1; + } + } + HandshakeResult::Error(e) => panic!("unexpected handshake error in churn test: {e}"), + } + } + + assert_eq!( + duplicate_success, 1, + "duplicate replay group must allow exactly one successful handshake" + ); + assert_eq!( + duplicate_reject, 127, + "duplicate replay group must reject all remaining replays" + ); + assert_eq!( + unique_success, 128, + "unique digest group must fully pass under replay churn" + ); + assert_eq!( + unique_reject, 0, + "unique digest group must not be falsely rejected as replay" + ); +} + +#[tokio::test] +async fn invalid_tls_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.23:44322".parse().unwrap(); + + let mut invalid = vec![0x42u8; 
tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let before = replay_checker.stats(); + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn empty_decoded_secret_is_rejected() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + let config = test_config_with_secret_hex(""); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.24:44323".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn wrong_length_decoded_secret_is_rejected() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + let config = test_config_with_secret_hex("aa"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.25:44324".parse().unwrap(); + let handshake = make_valid_tls_handshake(&[0xaau8], 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.26:44325".parse().unwrap(); + let handshake = [0u8; HANDSHAKE_LEN]; + + let before = replay_checker.stats(); + let result = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + let after = replay_checker.stats(); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(before.total_additions, after.total_additions); + assert_eq!(before.total_hits, after.total_hits); +} + +#[tokio::test] +async fn mixed_secret_lengths_keep_valid_user_authenticating() { + let _probe_guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + clear_auth_probe_state_for_testing(); + let good_secret = [0x22u8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config + .access + .users + .insert("broken_user".to_string(), "aa".to_string()); + config + .access + .users + .insert("valid_user".to_string(), "22222222222222222222222222222222".to_string()); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.27:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&good_secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + +#[tokio::test] +async fn 
tls_sni_preferred_user_hint_selects_matching_identity_first() { + let shared_secret = [0x3Bu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert( + "user-a".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.users.insert( + "user-b".to_string(), + "3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b3b".to_string(), + ); + config.access.ignore_time_skew = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.188:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + "user-b", + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, "user-b", + "TLS SNI preferred-user hint must select matching identity before equivalent decoys" + ); + } + _ => panic!("TLS handshake must succeed for valid shared-secret SNI case"), + } +} + +#[test] +fn stress_decode_user_secrets_keeps_preferred_user_first_in_large_set() { + let mut config = ProxyConfig::default(); + config.access.users.clear(); + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex.clone()); + + let decoded = decode_user_secrets(&config, Some(preferred_user.as_str())); + assert_eq!( + decoded.len(), + config.access.users.len(), + "decoded secret set must preserve full user cardinality under stress" + ); + assert_eq!( + decoded.first().map(|(name, _)| name.as_str()), + Some(preferred_user.as_str()), + "preferred user must be 
first even under adversarial large user sets" + ); + assert_eq!( + decoded + .iter() + .filter(|(name, _)| name == &preferred_user) + .count(), + 1, + "preferred user must appear exactly once in decoded list" + ); +} + +#[tokio::test] +async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { + let shared_secret = [0x7Fu8; 16]; + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.ignore_time_skew = true; + + let preferred_user = "target-user.example".to_string(); + let secret_hex = "7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f".to_string(); + + for i in 0..4096usize { + config.access.users.insert( + format!("decoy-{i:04}.example"), + secret_hex.clone(), + ); + } + config + .access + .users + .insert(preferred_user.clone(), secret_hex); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.189:44326".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn( + &shared_secret, + 0, + preferred_user.as_str(), + &[b"h2"], + ); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + match result { + HandshakeResult::Success((_, _, user)) => { + assert_eq!( + user, + preferred_user, + "SNI preferred-user hint must remain stable under large user cardinality" + ); + } + _ => panic!("TLS handshake must succeed for valid preferred-user stress case"), + } +} + +#[tokio::test] +async fn alpn_enforce_rejects_unsupported_client_alpn() { + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.28:44327".parse().unwrap(); + let handshake = 
make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn alpn_enforce_accepts_h2() { + let secret = [0x44u8; 16]; + let mut config = test_config_with_secret_hex("44444444444444444444444444444444"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.29:44328".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + +#[tokio::test] +async fn malformed_tls_classes_complete_within_bounded_time() { + let secret = [0x55u8; 16]; + let mut config = test_config_with_secret_hex("55555555555555555555555555555555"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(512, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.30:44329".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + for probe in [too_short, bad_hmac, alpn_mismatch] { + let result = tokio::time::timeout( + Duration::from_millis(200), + handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ), + ) + .await + .expect("Malformed TLS classes must be rejected within bounded time"); + + 
assert!(matches!(result, HandshakeResult::BadClient { .. })); + } +} + +#[tokio::test] +async fn tls_invalid_hmac_respects_configured_anti_fingerprint_delay() { + let secret = [0x5Au8; 16]; + let mut config = test_config_with_secret_hex("5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"); + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.32:44331".parse().unwrap(); + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let started = Instant::now(); + let result = handle_tls_handshake( + &bad_hmac, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to invalid TLS handshakes" + ); +} + +#[tokio::test] +async fn tls_alpn_mismatch_respects_configured_anti_fingerprint_delay() { + let secret = [0x6Bu8; 16]; + let mut config = test_config_with_secret_hex("6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b6b"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.33:44332".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let started = Instant::now(); + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + assert!( + started.elapsed() >= Duration::from_millis(18), + "configured anti-fingerprint delay must apply to ALPN-mismatch rejects" + ); +} + +#[tokio::test] +#[ignore = "timing-sensitive; run manually on low-jitter hosts"] +async fn malformed_tls_classes_share_close_latency_buckets() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let secret = [0x99u8; 16]; + let mut config = test_config_with_secret_hex("99999999999999999999999999999999"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.31:44330".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let mut class_means_ms = Vec::new(); + for probe in [too_short, bad_hmac, alpn_mismatch] { + let mut sum_micros: u128 = 0; + for _ in 0..ITER { + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + sum_micros += elapsed.as_micros(); + } + + class_means_ms.push(sum_micros / ITER as u128 / 1_000); + } + + let min_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .min() + .unwrap(); + let max_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .max() + .unwrap(); + + assert!( + max_bucket <= min_bucket + 1, + "Malformed TLS classes diverged across latency buckets: means_ms={:?}", + class_means_ms + ); +} + +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_tls_classes_under_fixed_delay_budget() { + const ITER: usize = 48; + const BUCKET_MS: u128 = 10; + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.alpn_enforce = true; + config.censorship.server_hello_delay_min_ms = 20; + config.censorship.server_hello_delay_max_ms = 20; + + let rng = SecureRandom::new(); + let base_ip = std::net::Ipv4Addr::new(198, 51, 100, 34); + + let too_short = vec![0x16, 0x03, 0x01]; + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let valid_h2 = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2"]); + + let classes = vec![ + ("too_short", too_short), + ("bad_hmac", bad_hmac), + ("alpn_mismatch", alpn_mismatch), + ("valid_h2", valid_h2), + ]; + + for (class, probe) in classes { + let mut samples_ms = Vec::with_capacity(ITER); + for idx in 0..ITER { + clear_auth_probe_state_for_testing(); + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let peer: SocketAddr = SocketAddr::from((base_ip, 44_000 + idx as u16)); + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + 
samples_ms.push(elapsed.as_millis()); + + if class == "valid_h2" { + assert!(matches!(result, HandshakeResult::Success(_))); + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + } + + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + + println!( + "TIMING_MATRIX tls class={} mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + class, + mean, + min, + p95, + max, + (mean as u128) / BUCKET_MS + ); + } +} + +#[test] +fn secure_tag_requires_tls_mode_on_tls_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be rejected when tls mode is disabled" + ); + + config.general.modes.tls = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be accepted when tls mode is enabled" + ); +} + +#[test] +fn secure_tag_requires_secure_mode_on_direct_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = false; + config.general.modes.tls = true; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be rejected when secure mode is disabled" + ); + + config.general.modes.secure = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be accepted when secure mode is enabled" + ); +} + +#[test] +fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() { + let tags = [ProtoTag::Secure, ProtoTag::Intermediate, 
ProtoTag::Abridged]; + + for classic in [false, true] { + for secure in [false, true] { + for tls in [false, true] { + let mut config = ProxyConfig::default(); + config.general.modes.classic = classic; + config.general.modes.secure = secure; + config.general.modes.tls = tls; + + for is_tls in [false, true] { + for tag in tags { + let expected = match (tag, is_tls) { + (ProtoTag::Secure, true) => tls, + (ProtoTag::Secure, false) => secure, + (ProtoTag::Intermediate | ProtoTag::Abridged, _) => classic, + }; + + assert_eq!( + mode_enabled_for_proto(&config, tag, is_tls), + expected, + "mode policy drifted for tag={:?}, transport_tls={}, modes=(classic={}, secure={}, tls={})", + tag, + is_tls, + classic, + secure, + tls + ); + } + } + } + } + } +} + +#[test] +fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1)); + warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2)); + + let warned = INVALID_SECRET_WARNED + .get() + .expect("warned set must be initialized"); + let guard = warned.lock().expect("warned set lock must be available"); + assert_eq!( + guard.len(), + 2, + "(name, reason) pairs that stringify to the same colon-joined key must remain distinct" + ); +} + +#[test] +fn invalid_secret_warning_cache_is_bounded() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + for idx in 0..(WARNED_SECRET_MAX_ENTRIES + 32) { + let user = format!("warned_user_{idx}"); + warn_invalid_secret_once(&user, "invalid_length", ACCESS_SECRET_BYTES, Some(idx)); + } + + let warned = INVALID_SECRET_WARNED + .get() + .expect("warned set must be initialized"); + let guard = warned.lock().expect("warned set lock must be available"); + assert_eq!( + 
guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "invalid-secret warning cache must remain bounded" + ); +} + +#[tokio::test] +async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.61:44361".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert!( + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "invalid probe burst must grow pre-auth failure streak to backoff threshold" + ); +} + +#[tokio::test] +async fn successful_tls_handshake_clears_pre_auth_failure_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x23u8; 16]; + let config = test_config_with_secret_hex("23232323232323232323232323232323"); + let replay_checker = ReplayChecker::new(256, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.62:44362".parse().unwrap(); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &invalid, + 
tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(expected), + "failure streak must grow before a successful authentication" + ); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let success = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(success, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful authentication must clear accumulated pre-auth failures" + ); +} + +#[test] +fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { + let state = DashMap::new(); + let now = Instant::now(); + let stale_seen = now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 1, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: stale_seen, + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.get(&newcomer).map(|entry| entry.fail_streak), + Some(1), + "stale-entry pruning must admit and track a new probe source" + ); + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must remain bounded after stale pruning" + ); +} + +#[test] +fn auth_probe_capacity_fresh_full_map_still_tracks_newcomer_with_bounded_eviction() { + let _guard = auth_probe_test_lock() + .lock() + .expect("auth probe test lock must be available"); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip 
= IpAddr::V4(Ipv4Addr::new( + 172, + 16, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let oldest = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 0)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 55)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!( + state.get(&newcomer).is_some(), + "fresh-at-cap auth probe map must still track a new source after bounded eviction" + ); + assert!( + state.get(&oldest).is_none(), + "capacity eviction must remove the oldest tracked source first" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map must stay at configured cap after bounded eviction" + ); + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now), + "capacity pressure should still activate coarse global pre-auth throttling" + ); +} + +#[test] +fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 2, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 2048) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 0, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, 
now); + + assert!( + state.get(&newcomer).is_some(), + "new source must still be tracked under sustained at-capacity churn" + ); + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + "auth probe map size must stay hard-bounded at capacity" + ); + } +} + +#[test] +fn auth_probe_over_cap_churn_still_tracks_newcomer_after_round_limit() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 32; + + for idx in 0..initial { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 6, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 114, 77)); + auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_secs(1)); + + assert!( + state.get(&newcomer).is_some(), + "new probe source must still be tracked even when map starts above hard cap" + ); + assert!( + state.len() < initial + 1, + "round-limited eviction path must still reclaim capacity under over-cap churn" + ); +} + +#[test] +fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + // Fill map at capacity with mostly high fail streak entries. 
+ for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 9, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let low_fail = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 1)); + state.insert( + low_fail, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_secs(30), + }, + ); + + let high_fail_old = IpAddr::V4(Ipv4Addr::new(172, 21, 0, 2)); + state.insert( + high_fail_old, + AuthProbeState { + fail_streak: 12, + blocked_until: now, + last_seen: now - Duration::from_secs(10), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 201)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&low_fail).is_none(), + "least-penalized entry should be evicted before high-penalty entries" + ); + assert!( + state.get(&high_fail_old).is_some(), + "high fail-streak entry should be preserved under mixed-priority eviction" + ); +} + +#[test] +fn auth_probe_capacity_tie_breaker_evicts_oldest_with_equal_fail_streak() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 30, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 5, + blocked_until: now, + last_seen: now + Duration::from_millis(idx as u64 + 1), + }, + ); + } + + let oldest = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 1)); + let newer = IpAddr::V4(Ipv4Addr::new(172, 31, 0, 2)); + state.insert( + oldest, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: 
now - Duration::from_secs(20), + }, + ); + state.insert( + newer, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 202)); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert!(state.get(&newcomer).is_some(), "new source must be tracked"); + assert!( + state.get(&oldest).is_none(), + "among equal fail streak candidates, oldest entry must be evicted" + ); + assert!( + state.get(&newer).is_some(), + "newer equal-priority entry should be retained" + ); +} + +#[test] +fn stress_auth_probe_capacity_churn_preserves_high_fail_sentinels() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel_a = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + let sentinel_b = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 251)); + + state.insert( + sentinel_a, + AuthProbeState { + fail_streak: 20, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(30), + }, + ); + state.insert( + sentinel_b, + AuthProbeState { + fail_streak: 21, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(31), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES - 2) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 4, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + for step in 0..1024usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 1, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + let now = base_now + Duration::from_millis(10_000 + step as u64); + auth_probe_record_failure_with_state(&state, newcomer, now); + + assert_eq!( + state.len(), + AUTH_PROBE_TRACK_MAX_ENTRIES, + 
"auth probe map must remain hard-bounded at capacity" + ); + assert!( + state.get(&sentinel_a).is_some() && state.get(&sentinel_b).is_some(), + "high fail-streak sentinels should survive low-streak newcomer churn" + ); + } +} + +#[test] +fn auth_probe_ipv6_is_bucketed_by_prefix_64() { + let state = DashMap::new(); + let now = Instant::now(); + + let ip_a = IpAddr::V6("2001:db8:abcd:1234:1:2:3:4".parse().unwrap()); + let ip_b = IpAddr::V6("2001:db8:abcd:1234:ffff:eeee:dddd:cccc".parse().unwrap()); + + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now); + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now); + + let normalized = normalize_auth_probe_ip(ip_a); + assert_eq!( + state.len(), + 1, + "IPv6 sources in the same /64 must share one pre-auth throttle bucket" + ); + assert_eq!( + state.get(&normalized).map(|entry| entry.fail_streak), + Some(2), + "failures from the same /64 must accumulate in one throttle state" + ); +} + +#[test] +fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() { + let state = DashMap::new(); + let now = Instant::now(); + + let ip_a = IpAddr::V6("2001:db8:1111:2222:1:2:3:4".parse().unwrap()); + let ip_b = IpAddr::V6("2001:db8:1111:3333:1:2:3:4".parse().unwrap()); + + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now); + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now); + + assert_eq!( + state.len(), + 2, + "different IPv6 /64 prefixes must not share throttle buckets" + ); + assert_eq!( + state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak), + Some(1) + ); + assert_eq!( + state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak), + Some(1) + ); +} + +#[test] +fn auth_probe_success_clears_whole_ipv6_prefix_bucket() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = 
Instant::now(); + let ip_fail = IpAddr::V6("2001:db8:aaaa:bbbb:1:2:3:4".parse().unwrap()); + let ip_success = IpAddr::V6("2001:db8:aaaa:bbbb:ffff:eeee:dddd:cccc".parse().unwrap()); + + auth_probe_record_failure(ip_fail, now); + assert_eq!( + auth_probe_fail_streak_for_testing(ip_fail), + Some(1), + "precondition: normalized prefix bucket must exist" + ); + + auth_probe_record_success(ip_success); + assert_eq!( + auth_probe_fail_streak_for_testing(ip_fail), + None, + "success from the same /64 must clear the shared bucket" + ); +} + +#[test] +fn auth_probe_eviction_offset_varies_with_input() { + let now = Instant::now(); + let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10)); + let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11)); + + let a = auth_probe_eviction_offset(ip1, now); + let b = auth_probe_eviction_offset(ip1, now); + let c = auth_probe_eviction_offset(ip2, now); + + assert_eq!(a, b, "same input must yield deterministic offset"); + assert_ne!(a, c, "different peer IPs should not collapse to one offset"); +} + +#[test] +fn auth_probe_eviction_offset_changes_with_time_component() { + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 77)); + let now = Instant::now(); + let later = now + Duration::from_millis(1); + + let a = auth_probe_eviction_offset(ip, now); + let b = auth_probe_eviction_offset(ip, later); + + assert_ne!( + a, b, + "eviction offset must incorporate timestamp entropy and not only peer IP" + ); +} + + +#[test] +fn auth_probe_round_limited_overcap_eviction_marks_saturation_and_keeps_newcomer_trackable() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let now = Instant::now(); + let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 64; + + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250)); + state.insert( + sentinel, + AuthProbeState { + fail_streak: 25, + blocked_until: now, + last_seen: now - 
Duration::from_secs(30), + }, + ); + + for idx in 0..(initial - 1) { + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + 20, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis((idx % 1024) as u64), + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 40)); + auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(1)); + + assert!(state.get(&newcomer).is_some(), "newcomer must still be tracked under over-cap pressure"); + assert!( + state.get(&sentinel).is_some(), + "high fail-streak sentinel must survive round-limited eviction" + ); + assert!( + auth_probe_saturation_is_throttled_at_for_testing(now + Duration::from_millis(1)), + "round-limited over-cap path must activate saturation throttle marker" + ); +} + +#[tokio::test] +async fn gap_t01_short_tls_probe_burst_is_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.171:44361".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS { + let result = handle_tls_handshake( + &too_short, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + } + + assert!( + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "short TLS probe bursts must increase auth-probe fail streak" + ); +} + +#[test] +fn stress_auth_probe_overcap_churn_does_not_starve_high_threat_sentinel_bucket() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let state = DashMap::new(); + let base_now = Instant::now(); + + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200)); + state.insert( + sentinel, + AuthProbeState { + fail_streak: 30, + blocked_until: base_now, + last_seen: base_now - Duration::from_secs(60), + }, + ); + + for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 80) { + let ip = IpAddr::V4(Ipv4Addr::new( + 172, + 22, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: base_now, + last_seen: base_now + Duration::from_millis((idx % 2048) as u64), + }, + ); + } + + for step in 0..512usize { + let newcomer = IpAddr::V4(Ipv4Addr::new( + 203, + 2, + ((step >> 8) & 0xff) as u8, + (step & 0xff) as u8, + )); + auth_probe_record_failure_with_state(&state, newcomer, base_now + Duration::from_millis(step as u64 + 1)); + + assert!( + state.get(&sentinel).is_some(), + "step {step}: high-threat sentinel must not be starved by newcomer churn" + ); + assert!(state.get(&newcomer).is_some(), "step {step}: newcomer must be tracked"); + } +} + +#[test] +fn light_fuzz_auth_probe_overcap_eviction_prefers_less_threatening_entries() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let mut s: u64 = 0xBADC_0FFE_EE11_2233; + + for round in 0..128usize { + let state = DashMap::new(); + let sentinel = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 180)); + state.insert( + sentinel, + 
AuthProbeState { + fail_streak: 18, + blocked_until: now, + last_seen: now - Duration::from_secs(5), + }, + ); + + for idx in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let ip = IpAddr::V4(Ipv4Addr::new( + 10, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + (s & 0xff) as u8, + )); + state.insert( + ip, + AuthProbeState { + fail_streak: 1, + blocked_until: now, + last_seen: now + Duration::from_millis((s & 1023) as u64), + }, + ); + } + + let newcomer = IpAddr::V4(Ipv4Addr::new(203, 10, ((round >> 8) & 0xff) as u8, (round & 0xff) as u8)); + auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_millis(round as u64 + 1)); + + assert!(state.get(&newcomer).is_some(), "round {round}: newcomer should be tracked"); + assert!( + state.get(&sentinel).is_some(), + "round {round}: high fail-streak sentinel should survive mixed low-threat pool" + ); + } +} +#[test] +fn light_fuzz_auth_probe_eviction_offset_is_deterministic_per_input_pair() { + let mut rng = StdRng::seed_from_u64(0xA11CE5EED); + let base = Instant::now(); + + for _ in 0..4096usize { + let ip = IpAddr::V4(Ipv4Addr::new(rng.random(), rng.random(), rng.random(), rng.random())); + let offset_ns = rng.random_range(0_u64..2_000_000); + let when = base + Duration::from_nanos(offset_ns); + + let first = auth_probe_eviction_offset(ip, when); + let second = auth_probe_eviction_offset(ip, when); + assert_eq!( + first, second, + "eviction offset must be stable for identical (ip, now) pairs" + ); + } +} + +#[test] +fn adversarial_eviction_offset_spread_avoids_single_bucket_collapse() { + let modulus = AUTH_PROBE_TRACK_MAX_ENTRIES; + let mut bucket_hits = vec![0usize; modulus]; + let now = Instant::now(); + + for idx in 0..8192usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 100, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + ((idx.wrapping_mul(37)) & 0xff) as u8, + )); + let bucket = auth_probe_eviction_offset(ip, now) % modulus; + bucket_hits[bucket] += 1; 
+ } + + let non_empty_buckets = bucket_hits.iter().filter(|&&hits| hits > 0).count(); + assert!( + non_empty_buckets >= modulus / 2, + "adversarial sequential input should cover a broad bucket set (covered {non_empty_buckets}/{modulus})" + ); + + let max_hits = bucket_hits.iter().copied().max().unwrap_or(0); + let min_non_zero_hits = bucket_hits + .iter() + .copied() + .filter(|&hits| hits > 0) + .min() + .unwrap_or(0); + assert!( + max_hits <= min_non_zero_hits.saturating_mul(32).max(1), + "bucket skew is unexpectedly extreme for keyed hasher spread (max={max_hits}, min_non_zero={min_non_zero_hits})" + ); +} + +#[test] +fn stress_auth_probe_eviction_offset_high_volume_uniqueness_sanity() { + let now = Instant::now(); + let mut seen = std::collections::HashSet::new(); + + for idx in 0..50_000usize { + let ip = IpAddr::V4(Ipv4Addr::new( + 198, + ((idx >> 16) & 0xff) as u8, + ((idx >> 8) & 0xff) as u8, + (idx & 0xff) as u8, + )); + seen.insert(auth_probe_eviction_offset(ip, now)); + } + + assert!( + seen.len() >= 40_000, + "high-volume eviction offsets should not collapse excessively under keyed hashing" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let peer_ip: IpAddr = "198.51.100.90".parse().unwrap(); + let tasks = 128usize; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::with_capacity(tasks); + + for _ in 0..tasks { + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + barrier.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle + .await + .expect("concurrent failure recording task must not panic"); + } + + let streak = auth_probe_fail_streak_for_testing(peer_ip) + .expect("tracked peer must exist after concurrent 
failure burst"); + assert_eq!( + streak as usize, + tasks, + "concurrent failures for one source must account every attempt" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x31u8; 16]; + let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); + let valid = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid = Arc::new(invalid); + + let mut noise_tasks = Vec::new(); + for idx in 0..96u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid = invalid.clone(); + noise_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx); + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + })); + } + + let victim_config = config.clone(); + let victim_replay_checker = replay_checker.clone(); + let victim_rng = rng.clone(); + let victim_valid = valid.clone(); + let victim_task = tokio::spawn(async move { + handle_tls_handshake( + &victim_valid, + tokio::io::empty(), + tokio::io::sink(), + victim_peer, + &victim_config, + &victim_replay_checker, + &victim_rng, + None, + ) + .await + }); + + for task in noise_tasks { + task.await.expect("noise task must not panic"); + } + + let victim_result = victim_task + .await + .expect("victim handshake task must not panic"); + assert!( + matches!(victim_result, HandshakeResult::Success(_)), + "invalid probe noise from other IPs must not block a valid victim handshake" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(victim_peer.ip()), + None, + "successful victim handshake must not retain pre-auth failure streak" + ); +} + +#[test] +fn auth_probe_saturation_state_expires_after_retention_window() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(30), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + assert!( + !auth_probe_saturation_is_throttled_for_testing(), + "expired saturation state must stop throttling and self-clear" + ); + + let guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none(), "expired saturation state must be removed"); +} + +#[tokio::test] +async fn global_saturation_marker_does_not_block_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| 
poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x41u8; 16]; + let config = test_config_with_secret_hex("41414141414141414141414141414141"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.101:45101".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "global saturation marker must not block valid authenticated TLS handshakes" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful handshake under saturation marker must not retain per-ip probe failures" + ); +} + +#[tokio::test] +async fn expired_global_saturation_allows_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x55u8; 16]; + let config = test_config_with_secret_hex("55555555555555555555555555555555"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.102:45102".parse().unwrap(); + + let now = Instant::now(); + let saturation = auth_probe_saturation_state(); + { + let mut guard = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + 
blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "expired saturation marker must not block valid handshake" + ); +} + +#[tokio::test] +async fn valid_tls_is_blocked_by_per_ip_preauth_throttle_without_saturation() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x61u8; 16]; + let config = test_config_with_secret_hex("61616161616161616161616161616161"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.103:45103".parse().unwrap(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: Instant::now() + Duration::from_secs(5), + last_seen: Instant::now(), + }, + ); + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); +} + +#[tokio::test] +async fn saturation_allows_valid_tls_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x62u8; 16]; + let config = test_config_with_secret_hex("62626262626262626262626262626262"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.104:45104".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful auth under saturation must clear the peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_tls_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.105:45105".parse().unwrap(); + let now = Instant::now(); 
+ { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid TLS during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_tls_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("63636363636363636363636363636363"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.205:45205".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + 
tls::TLS_DIGEST_LEN] = 32; + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid TLS" + ); +} + +#[tokio::test] +async fn saturation_allows_valid_mtproto_even_when_peer_ip_is_currently_throttled() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "64646464646464646464646464646464"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.secure = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.106:45106".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2); + let result = handle_mtproto_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + None, + "successful mtproto auth under saturation must clear the 
peer's throttled state" + ); +} + +#[tokio::test] +async fn saturation_still_rejects_invalid_mtproto_probe_and_records_failure() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.107:45107".parse().unwrap(); + let now = Instant::now(); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(1), + "invalid mtproto during saturation must still increment per-ip failure tracking" + ); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_preauth_throttles_repeated_invalid_mtproto_probe() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("65656565656565656565656565656565"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.206:45206".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "pre-auth throttle under exhausted saturation grace must reject without re-processing invalid MTProto" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_tls_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("70707070707070707070707070707070"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.207:45207".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + + for expected in [ + AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. 
})); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid TLS must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_progression_mtproto_reaches_cap_then_stops_incrementing() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("71717171717171717171717171717171"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "198.51.100.208:45208".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let invalid = [0u8; HANDSHAKE_LEN]; + + for expected in [ + 
AUTH_PROBE_BACKOFF_START_FAILS + 1, + AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + ] { + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), Some(expected)); + } + + { + let mut entry = auth_probe_state_map() + .get_mut(&normalize_auth_probe_ip(peer.ip())) + .expect("peer state must exist before exhaustion recheck"); + entry.fail_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS; + entry.blocked_until = Instant::now() + Duration::from_secs(1); + entry.last_seen = Instant::now(); + } + + let result = handle_mtproto_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "once grace is exhausted, repeated invalid MTProto must be pre-auth throttled without further fail-streak growth" + ); +} + +#[tokio::test] +async fn saturation_grace_boundary_still_admits_valid_tls_before_exhaustion() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x72u8; 16]; + let config = test_config_with_secret_hex("72727272727272727272727272727272"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.209:45209".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + 
blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::Success(_)), + "valid TLS should still pass while peer remains within saturation grace budget" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_blocks_valid_tls_until_backoff_expires() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x73u8; 16]; + let config = test_config_with_secret_hex("73737373737373737373737373737373"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.210:45210".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_millis(200), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let blocked = handle_tls_handshake( + &valid, + tokio::io::empty(), + 
tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(blocked, HandshakeResult::BadClient { .. })); + + tokio::time::sleep(Duration::from_millis(230)).await; + + let allowed = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!( + matches!(allowed, HandshakeResult::Success(_)), + "valid TLS should recover after peer-specific pre-auth backoff has elapsed" + ); + assert_eq!(auth_probe_fail_streak_for_testing(peer.ip()), None); +} + +#[tokio::test] +async fn saturation_grace_exhaustion_is_shared_across_tls_and_mtproto_for_same_peer() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("74747474747474747474747474747474"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.211:45211".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_mtproto = [0u8; HANDSHAKE_LEN]; + + let tls_result = handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, 
+ ) + .await; + assert!(matches!(tls_result, HandshakeResult::BadClient { .. })); + + let mtproto_result = handle_mtproto_handshake( + &invalid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(mtproto_result, HandshakeResult::BadClient { .. })); + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "saturation grace exhaustion must gate both TLS and MTProto pre-auth paths for one peer" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_same_peer_invalid_tls_storm_does_not_bypass_saturation_grace_cap() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = Arc::new(test_config_with_secret_hex("75757575757575757575757575757575")); + let replay_checker = Arc::new(ReplayChecker::new(1024, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let peer: SocketAddr = "198.51.100.212:45212".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut tasks = Vec::new(); + for _ in 0..64usize { + let config = config.clone(); + 
let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid_tls = invalid_tls.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + for task in tasks { + let result = task.await.unwrap(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + assert_eq!( + auth_probe_fail_streak_for_testing(peer.ip()), + Some(AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS), + "same-peer invalid storm under exhausted grace must stay pre-auth throttled without fail-streak growth" + ); +} + +#[tokio::test] +async fn light_fuzz_saturation_grace_tls_invalid_inputs_never_authenticate_or_panic() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("76767676767676767676767676767676"); + let replay_checker = ReplayChecker::new(2048, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.213:45213".parse().unwrap(); + let now = Instant::now(); + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(1), + last_seen: now, + }); + } + + let mut seeded = StdRng::seed_from_u64(0xD15EA5E5_u64); + for _ in 0..128usize { + let len = seeded.random_range(0usize..96usize); + let mut probe = vec![0u8; len]; + seeded.fill(&mut probe[..]); + + let result = handle_tls_handshake( + 
&probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + + let streak = auth_probe_fail_streak_for_testing(peer.ip()) + .expect("peer should remain tracked after repeated invalid fuzz probes"); + assert!( + streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS, + "fuzzed invalid TLS probes under saturation must not reduce fail-streak below exhaustion threshold" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_saturation_burst_only_admits_valid_tls_and_mtproto_handshakes() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret_hex = "66666666666666666666666666666666"; + let secret = [0x66u8; 16]; + let mut cfg = test_config_with_secret_hex(secret_hex); + cfg.general.modes.secure = true; + let config = Arc::new(cfg); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let valid_tls = Arc::new(make_valid_tls_handshake(&secret, 0)); + let valid_mtproto = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 3)); + let mut invalid_tls = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid_tls[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid_tls = Arc::new(invalid_tls); + + let mut invalid_tls_tasks = Vec::new(); + for idx in 0..48u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let 
invalid_tls = invalid_tls.clone(); + invalid_tls_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 46000 + idx); + handle_tls_handshake( + &invalid_tls, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let valid_tls_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let valid_tls = valid_tls.clone(); + tokio::spawn(async move { + handle_tls_handshake( + &valid_tls, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.108:45108".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + }) + }; + + let valid_mtproto_task = { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let valid_mtproto = valid_mtproto.clone(); + tokio::spawn(async move { + handle_mtproto_handshake( + &valid_mtproto, + tokio::io::empty(), + tokio::io::sink(), + "198.51.100.109:45109".parse().unwrap(), + &config, + &replay_checker, + false, + None, + ) + .await + }) + }; + + let mut bad_clients = 0usize; + for task in invalid_tls_tasks { + match task.await.unwrap() { + HandshakeResult::BadClient { .. 
} => bad_clients += 1, + HandshakeResult::Success(_) => panic!("invalid TLS probe unexpectedly authenticated"), + HandshakeResult::Error(err) => panic!("unexpected error in invalid TLS saturation burst test: {err}"), + } + } + + let valid_tls_result = valid_tls_task.await.unwrap(); + assert!( + matches!(valid_tls_result, HandshakeResult::Success(_)), + "valid TLS probe must authenticate during saturation burst" + ); + + let valid_mtproto_result = valid_mtproto_task.await.unwrap(); + assert!( + matches!(valid_mtproto_result, HandshakeResult::Success(_)), + "valid MTProto probe must authenticate during saturation burst" + ); + + assert_eq!( + bad_clients, + 48, + "all invalid TLS probes in mixed saturation burst must be rejected" + ); +} + +#[tokio::test] +async fn expired_saturation_keeps_per_ip_throttle_enforced_for_valid_tls() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x67u8; 16]; + let config = test_config_with_secret_hex("67676767676767676767676767676767"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.110:45110".parse().unwrap(); + let now = Instant::now(); + + auth_probe_state_map().insert( + normalize_auth_probe_ip(peer.ip()), + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now - Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1), + }); + } + + let valid = make_valid_tls_handshake(&secret, 0); + let result = handle_tls_handshake( + &valid, + tokio::io::empty(), + tokio::io::sink(), + peer, 
+ &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!( + matches!(result, HandshakeResult::BadClient { .. }), + "expired saturation marker must not disable per-ip pre-auth throttle" + ); +} diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs new file mode 100644 index 0000000..1b30067 --- /dev/null +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -0,0 +1,517 @@ +use super::*; +use std::collections::BTreeSet; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum PathClass { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +fn mean_ms(samples: &[u128]) -> f64 { + if samples.is_empty() { + return 0.0; + } + let sum: u128 = samples.iter().copied().sum(); + sum as f64 / samples.len() as f64 +} + +fn percentile_ms(mut values: Vec, p_num: usize, p_den: usize) -> u128 { + values.sort_unstable(); + if values.is_empty() { + return 0; + } + let idx = ((values.len() - 1) * p_num) / p_den; + values[idx] +} + +fn bucketize_ms(values: &[u128], bucket_ms: u128) -> Vec { + values.iter().map(|v| *v / bucket_ms).collect() +} + +fn best_threshold_accuracy_u128(a: &[u128], b: &[u128]) -> f64 { + let min_v = *a.iter().chain(b.iter()).min().unwrap(); + let max_v = *a.iter().chain(b.iter()).max().unwrap(); + + let mut best = 0.0f64; + for t in min_v..=max_v { + let correct_a = a.iter().filter(|&&x| x <= t).count(); + let correct_b = b.iter().filter(|&&x| x > t).count(); + let acc = (correct_a + correct_b) as f64 / (a.len() + b.len()) as f64; + if acc > best { + best = acc; + } + } + best +} + +fn spread_u128(values: &[u128]) -> u128 { + if values.is_empty() { + return 0; + } + let min_v = *values.iter().min().unwrap(); + let max_v = *values.iter().max().unwrap(); + max_v - min_v +} + +async fn collect_timing_samples(path: PathClass, timing_norm_enabled: bool, n: 
usize) -> Vec { + let mut out = Vec::with_capacity(n); + for _ in 0..n { + out.push(measure_masking_duration_ms(path, timing_norm_enabled).await); + } + out +} + +async fn measure_masking_duration_ms(path: PathClass, timing_norm_enabled: bool) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_timing_normalization_enabled = timing_norm_enabled; + config.censorship.mask_timing_normalization_floor_ms = 220; + config.censorship.mask_timing_normalization_ceiling_ms = 260; + + let accept_task = match path { + PathClass::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + None + } + PathClass::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + PathClass::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(320)).await; + })) + } + }; + + let (client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let peer: SocketAddr = "198.51.100.230:57230".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /ab-harness 
HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = accept_task { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +async fn capture_above_cap_forwarded_len( + body_sent: usize, + above_cap_blur_enabled: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur_enabled; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (client_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + + let peer: SocketAddr = "198.51.100.231:57231".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut initial = vec![0u8; 5 + body_sent]; + initial[0] = 0x16; + initial[1] = 0x03; + initial[2] = 0x01; + initial[3..5].copy_from_slice(&7000u16.to_be_bytes()); + initial[5..].fill(0x5A); + + let fallback_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + 
client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(4), fallback_task) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn integration_ab_harness_envelope_and_blur_improve_obfuscation_vs_baseline() { + const ITER: usize = 8; + + let mut baseline_fail = Vec::with_capacity(ITER); + let mut baseline_success = Vec::with_capacity(ITER); + let mut baseline_slow = Vec::with_capacity(ITER); + + let mut hardened_fail = Vec::with_capacity(ITER); + let mut hardened_success = Vec::with_capacity(ITER); + let mut hardened_slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + baseline_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, false).await); + baseline_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, false).await); + baseline_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, false).await); + + hardened_fail.push(measure_masking_duration_ms(PathClass::ConnectFail, true).await); + hardened_success.push(measure_masking_duration_ms(PathClass::ConnectSuccess, true).await); + hardened_slow.push(measure_masking_duration_ms(PathClass::SlowBackend, true).await); + } + + let baseline_means = [ + mean_ms(&baseline_fail), + mean_ms(&baseline_success), + mean_ms(&baseline_slow), + ]; + let hardened_means = [ + mean_ms(&hardened_fail), + mean_ms(&hardened_success), + mean_ms(&hardened_slow), + ]; + + let baseline_range = baseline_means + .iter() + .copied() + .fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), v| { + (mn.min(v), mx.max(v)) + }); + let hardened_range = hardened_means + .iter() + .copied() + .fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), v| { + (mn.min(v), mx.max(v)) + }); + + let baseline_spread = baseline_range.1 - baseline_range.0; + let hardened_spread = hardened_range.1 - hardened_range.0; + + println!( + "ab_harness_timing baseline_means={:?} hardened_means={:?} 
baseline_spread={:.2} hardened_spread={:.2}", + baseline_means, hardened_means, baseline_spread, hardened_spread + ); + + assert!( + hardened_spread < baseline_spread, + "timing envelope should reduce cross-path mean spread: baseline={baseline_spread:.2} hardened={hardened_spread:.2}" + ); + + let mut baseline_a = BTreeSet::new(); + let mut baseline_b = BTreeSet::new(); + let mut hardened_a = BTreeSet::new(); + let mut hardened_b = BTreeSet::new(); + + for _ in 0..24 { + baseline_a.insert(capture_above_cap_forwarded_len(5000, false, 0).await); + baseline_b.insert(capture_above_cap_forwarded_len(5040, false, 0).await); + + hardened_a.insert(capture_above_cap_forwarded_len(5000, true, 96).await); + hardened_b.insert(capture_above_cap_forwarded_len(5040, true, 96).await); + } + + let baseline_overlap = baseline_a.intersection(&baseline_b).count(); + let hardened_overlap = hardened_a.intersection(&hardened_b).count(); + + println!( + "ab_harness_length baseline_overlap={} hardened_overlap={} baseline_a={} baseline_b={} hardened_a={} hardened_b={}", + baseline_overlap, + hardened_overlap, + baseline_a.len(), + baseline_b.len(), + hardened_a.len(), + hardened_b.len() + ); + + assert_eq!(baseline_overlap, 0, "baseline above-cap classes should be disjoint"); + assert!( + hardened_overlap > baseline_overlap, + "above-cap blur should increase cross-class overlap: baseline={} hardened={}", + baseline_overlap, + hardened_overlap + ); +} + +#[test] +fn timing_classifier_helper_bucketize_is_stable() { + let values = vec![219u128, 220, 239, 240, 259, 260]; + let got = bucketize_ms(&values, 20); + assert_eq!(got, vec![10, 11, 11, 12, 12, 13]); +} + +#[test] +fn timing_classifier_helper_percentile_is_monotonic() { + let samples = vec![210u128, 220, 230, 240, 250, 260, 270, 280]; + let p50 = percentile_ms(samples.clone(), 50, 100); + let p95 = percentile_ms(samples.clone(), 95, 100); + assert!(p95 >= p50); +} + +#[test] +fn 
timing_classifier_helper_threshold_accuracy_is_perfect_for_disjoint_sets() { + let a = vec![10u128, 11, 12, 13, 14]; + let b = vec![20u128, 21, 22, 23, 24]; + let acc = best_threshold_accuracy_u128(&a, &b); + assert!(acc >= 0.99); +} + +#[test] +fn timing_classifier_helper_threshold_accuracy_drops_for_identical_sets() { + let a = vec![10u128, 11, 12, 13, 14]; + let b = vec![10u128, 11, 12, 13, 14]; + let acc = best_threshold_accuracy_u128(&a, &b); + assert!(acc <= 0.6, "identical sets should not be strongly separable"); +} + +#[test] +fn timing_classifier_helper_bucketed_threshold_reduces_resolution() { + let raw_a = vec![221u128, 223, 225, 227, 229]; + let raw_b = vec![231u128, 233, 235, 237, 239]; + let raw_acc = best_threshold_accuracy_u128(&raw_a, &raw_b); + + let bucketed_a = bucketize_ms(&raw_a, 20); + let bucketed_b = bucketize_ms(&raw_b, 20); + let bucketed_acc = best_threshold_accuracy_u128(&bucketed_a, &bucketed_b); + + assert!(raw_acc >= bucketed_acc); +} + +#[tokio::test] +async fn timing_classifier_baseline_connect_fail_vs_slow_backend_is_highly_separable() { + let fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; + + let acc = best_threshold_accuracy_u128(&fail, &slow); + assert!(acc >= 0.80, "baseline timing classes should be separable enough"); +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_fail_vs_slow_backend_reduces_separability() { + let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 8).await; + let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 8).await; + + let baseline_acc = best_threshold_accuracy_u128(&baseline_fail, &baseline_slow); + let hardened_acc = best_threshold_accuracy_u128(&hardened_fail, 
&hardened_slow); + + assert!( + hardened_acc <= baseline_acc, + "normalization should not increase timing separability" + ); +} + +#[tokio::test] +async fn timing_classifier_bucketed_normalized_connect_fail_vs_slow_backend_is_bounded() { + let baseline_fail = collect_timing_samples(PathClass::ConnectFail, false, 10).await; + let baseline_slow = collect_timing_samples(PathClass::SlowBackend, false, 10).await; + let hardened_fail = collect_timing_samples(PathClass::ConnectFail, true, 10).await; + let hardened_slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await; + + let baseline_acc = best_threshold_accuracy_u128( + &bucketize_ms(&baseline_fail, 20), + &bucketize_ms(&baseline_slow, 20), + ); + let hardened_acc = best_threshold_accuracy_u128( + &bucketize_ms(&hardened_fail, 20), + &bucketize_ms(&hardened_slow, 20), + ); + + assert!( + hardened_acc <= baseline_acc, + "normalized bucketed classifier should not outperform baseline: baseline={baseline_acc:.3} hardened={hardened_acc:.3}" + ); +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_fail_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::ConnectFail, true, 6).await; + for s in samples { + assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_connect_success_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::ConnectSuccess, true, 6).await; + for s in samples { + assert!((150..=1200).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_slow_backend_samples_stay_in_sane_bounds() { + let samples = collect_timing_samples(PathClass::SlowBackend, true, 6).await; + for s in samples { + assert!((150..=1400).contains(&s), "sample out of sane bounds: {s}"); + } +} + +#[tokio::test] +async fn timing_classifier_normalized_mean_bucket_delta_connect_fail_vs_connect_success_is_small() { + let fail 
= collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let success = collect_timing_samples(PathClass::ConnectSuccess, true, 8).await; + let fail_mean = mean_ms(&fail); + let success_mean = mean_ms(&success); + let delta_bucket = ((fail_mean as i128 - success_mean as i128).abs()) / 20; + assert!(delta_bucket <= 3, "mean bucket delta too large: {delta_bucket}"); +} + +#[tokio::test] +async fn timing_classifier_normalized_p95_bucket_delta_connect_success_vs_slow_is_small() { + let success = collect_timing_samples(PathClass::ConnectSuccess, true, 10).await; + let slow = collect_timing_samples(PathClass::SlowBackend, true, 10).await; + let p95_success = percentile_ms(success, 95, 100); + let p95_slow = percentile_ms(slow, 95, 100); + let delta_bucket = ((p95_success as i128 - p95_slow as i128).abs()) / 20; + assert!(delta_bucket <= 4, "p95 bucket delta too large: {delta_bucket}"); +} + +#[tokio::test] +async fn timing_classifier_normalized_spread_is_not_worse_than_baseline_for_connect_fail() { + let baseline = collect_timing_samples(PathClass::ConnectFail, false, 8).await; + let hardened = collect_timing_samples(PathClass::ConnectFail, true, 8).await; + let baseline_spread = spread_u128(&baseline); + let hardened_spread = spread_u128(&hardened); + assert!( + hardened_spread <= baseline_spread.saturating_add(600), + "normalized spread exploded unexpectedly: baseline={baseline_spread} hardened={hardened_spread}" + ); +} + +#[tokio::test] +async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_under_normalization() { + let pairs = [ + (PathClass::ConnectFail, PathClass::ConnectSuccess), + (PathClass::ConnectFail, PathClass::SlowBackend), + (PathClass::ConnectSuccess, PathClass::SlowBackend), + ]; + + let mut meaningful_improvement_seen = false; + let mut baseline_sum = 0.0f64; + let mut hardened_sum = 0.0f64; + let mut pair_count = 0usize; + + for (a, b) in pairs { + let baseline_a = collect_timing_samples(a, false, 6).await; + let 
baseline_b = collect_timing_samples(b, false, 6).await; + let hardened_a = collect_timing_samples(a, true, 6).await; + let hardened_b = collect_timing_samples(b, true, 6).await; + + let baseline_acc = best_threshold_accuracy_u128( + &bucketize_ms(&baseline_a, 20), + &bucketize_ms(&baseline_b, 20), + ); + let hardened_acc = best_threshold_accuracy_u128( + &bucketize_ms(&hardened_a, 20), + &bucketize_ms(&hardened_b, 20), + ); + + // When baseline separability is near-random, tiny sample jitter can make + // hardened appear "worse" without indicating a real side-channel regression. + // Guard hard only on informative baseline pairs. + if baseline_acc >= 0.75 { + assert!( + hardened_acc <= baseline_acc + 0.05, + "normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3}" + ); + } + + if hardened_acc + 0.05 <= baseline_acc { + meaningful_improvement_seen = true; + } + + baseline_sum += baseline_acc; + hardened_sum += hardened_acc; + pair_count += 1; + } + + let baseline_avg = baseline_sum / pair_count as f64; + let hardened_avg = hardened_sum / pair_count as f64; + + assert!( + hardened_avg <= baseline_avg + 0.08, + "normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}" + ); + + // Optional signal only: do not require improvement on every run because + // noisy CI schedulers can flatten pairwise differences at low sample counts. 
+ let _ = meaningful_improvement_seen; +} + +#[tokio::test] +async fn timing_classifier_stress_parallel_sampling_finishes_and_stays_bounded() { + let mut tasks = Vec::new(); + for i in 0..24usize { + tasks.push(tokio::spawn(async move { + let class = match i % 3 { + 0 => PathClass::ConnectFail, + 1 => PathClass::ConnectSuccess, + _ => PathClass::SlowBackend, + }; + let sample = measure_masking_duration_ms(class, true).await; + assert!((100..=1600).contains(&sample), "stress sample out of bounds: {sample}"); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + } +} diff --git a/src/proxy/tests/masking_adversarial_tests.rs b/src/proxy/tests/masking_adversarial_tests.rs new file mode 100644 index 0000000..955e8ec --- /dev/null +++ b/src/proxy/tests/masking_adversarial_tests.rs @@ -0,0 +1,762 @@ +use super::*; +use std::sync::Arc; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Instant, Duration}; +use crate::config::ProxyConfig; +use crate::proxy::relay::relay_bidirectional; +use crate::stats::Stats; +use crate::stats::beobachten::BeobachtenStore; +use crate::stream::BufferPool; + +// ------------------------------------------------------------------ +// Probing Indistinguishability (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_probes_indistinguishable_timing() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 80; // Should timeout/refuse + + let peer: SocketAddr = "192.0.2.10:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + // Test different probe types + let probes = vec![ + (b"GET / HTTP/1.1\r\nHost: x\r\n\r\n".to_vec(), "HTTP"), + (b"SSH-2.0-probe".to_vec(), "SSH"), + (vec![0x16, 0x03, 
0x03, 0x00, 0x05, 0x01, 0x00, 0x00, 0x01, 0x00], "TLS-scanner"), + (vec![0x42; 5], "port-scanner"), + ]; + + for (probe, type_name) in probes { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ).await; + + let elapsed = start.elapsed(); + + // We expect any outcome to take roughly MASK_TIMEOUT (50ms in tests) + // to mask whether the backend was reachable or refused. + assert!(elapsed >= Duration::from_millis(30), "Probe {type_name} finished too fast: {elapsed:?}"); + } +} + +// ------------------------------------------------------------------ +// Masking Budget Stress Tests (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_budget_stress_under_load() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let peer: SocketAddr = "192.0.2.20:443".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = Arc::new(BeobachtenStore::new()); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let (client_reader, _client_writer) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let config = config.clone(); + let beobachten = Arc::clone(&beobachten); + + tasks.push(tokio::spawn(async move { + let start = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"probe", + peer, + local_addr, + &config, + &beobachten, + ).await; + start.elapsed() + })); + } + + for task in tasks { + let elapsed = task.await.unwrap(); + assert!(elapsed >= Duration::from_millis(30), "Stress probe finished too fast: {elapsed:?}"); + } +} + +// 
------------------------------------------------------------------ +// detect_client_type Fingerprint Check +// ------------------------------------------------------------------ + +#[test] +fn test_detect_client_type_boundary_cases() { + // 9 bytes = port-scanner + assert_eq!(detect_client_type(&[0x42; 9]), "port-scanner"); + // 10 bytes = unknown + assert_eq!(detect_client_type(&[0x42; 10]), "unknown"); + + // HTTP verbs without trailing space + assert_eq!(detect_client_type(b"GET/"), "port-scanner"); // because len < 10 + assert_eq!(detect_client_type(b"GET /path"), "HTTP"); +} + +// ------------------------------------------------------------------ +// Priority 2: Slowloris and Slow Read Attacks (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_slowloris_client_idle_timeout_rejected() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut drip = [0u8; 1]; + let drip_read = tokio::time::timeout(Duration::from_millis(220), stream.read_exact(&mut drip)).await; + assert!( + drip_read.is_err() || drip_read.unwrap().is_err(), + "backend must not receive post-timeout slowloris drip bytes" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); + + let (mut 
client_writer, client_reader) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + tokio::time::sleep(Duration::from_millis(160)).await; + let _ = client_writer.write_all(b"X").await; + + handle.await.unwrap(); + accept_task.await.unwrap(); +} + +// ------------------------------------------------------------------ +// Priority 2: Fallback Server Down / Fingerprinting (OWASP ASVS 5.1.7) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_fallback_down_mimics_timeout() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; // Unlikely port + + let (server_reader, server_writer) = duplex(1024); + let beobachten = BeobachtenStore::new(); + let peer: SocketAddr = "192.0.2.12:12345".parse().unwrap(); + let local: SocketAddr = "192.0.2.1:443".parse().unwrap(); + + let start = Instant::now(); + handle_bad_client(server_reader, server_writer, b"GET / HTTP/1.1\r\n", peer, local, &config, &beobachten).await; + + let elapsed = start.elapsed(); + // It should wait for MASK_TIMEOUT (50ms in tests) even if connection was refused immediately + assert!(elapsed >= Duration::from_millis(40), "Must respect connect budget even on failure: {:?}", elapsed); +} + +// ------------------------------------------------------------------ +// Priority 2: SSRF Prevention (OWASP ASVS 5.1.2) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn masking_ssrf_resolve_internal_ranges_blocked() { + use crate::network::dns_overrides::resolve_socket_addr; + + let blocked_ips = ["127.0.0.1", "169.254.169.254", "10.0.0.1", "192.168.1.1", "0.0.0.0"]; + + for ip in 
blocked_ips { + assert!( + resolve_socket_addr(ip, 80).is_none(), + "runtime DNS overrides must not resolve unconfigured literal host targets" + ); + } +} + +#[tokio::test] +async fn masking_unknown_proxy_protocol_version_falls_back_to_v1_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut header = [0u8; 15]; + stream.read_exact(&mut header).await.unwrap(); + assert_eq!(&header, b"PROXY UNKNOWN\r\n"); + + let mut payload = [0u8; 5]; + stream.read_exact(&mut payload).await.unwrap(); + assert_eq!(&payload, b"probe"); + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 255; + + let peer: SocketAddr = "198.51.100.77:50001".parse().unwrap(); + let local_addr: SocketAddr = "[2001:db8::10]:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + + handle_bad_client( + client_reader, + client_visible_writer, + b"probe", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_zero_length_initial_data_does_not_hang_or_panic() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut one = [0u8; 1]; + let n = tokio::time::timeout(Duration::from_millis(150), stream.read(&mut one)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0, "backend must observe clean EOF for empty initial payload"); + }); + + let 
mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.70:50002".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (client_reader, client_writer) = duplex(64); + drop(client_writer); + let (_client_visible_reader, client_visible_writer) = duplex(64); + + handle_bad_client( + client_reader, + client_visible_writer, + b"", + peer, + local, + &config, + &beobachten, + ) + .await; + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_oversized_initial_payload_is_forwarded_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let payload = vec![0xA5u8; 32 * 1024]; + + let accept_task = tokio::spawn({ + let payload = payload.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; payload.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, payload, "large initial payload must stay byte-for-byte"); + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.71:50003".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + let (client_reader, _client_writer) = duplex(64); + let (_client_visible_reader, client_visible_writer) = duplex(64); + + handle_bad_client( + client_reader, + client_visible_writer, + &payload, + peer, + local, + &config, + &beobachten, + ) + .await; + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn masking_refused_backend_keeps_constantish_timing_floor_under_burst() { + 
let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + + let peer: SocketAddr = "203.0.113.72:50004".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + for _ in 0..16 { + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + assert!( + started.elapsed() >= Duration::from_millis(30), + "refused-backend path must keep timing floor to reduce fingerprinting" + ); + } +} + +#[tokio::test] +async fn masking_backend_half_close_then_client_half_close_completes_without_hang() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut pre = [0u8; 4]; + stream.read_exact(&mut pre).await.unwrap(); + assert_eq!(&pre, b"PING"); + stream.write_all(b"PONG").await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let peer: SocketAddr = "203.0.113.73:50005".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client_writer, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"PING", + peer, + local, + &config, + &beobachten, + ) + .await; + }); + 
+ client_writer.shutdown().await.unwrap(); + + let mut got = [0u8; 4]; + client_visible_reader.read_exact(&mut got).await.unwrap(); + assert_eq!(&got, b"PONG"); + + timeout(Duration::from_secs(2), handle) + .await + .expect("masking task must terminate after bilateral half-close") + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn chaos_burst_reconnect_storm_for_masking_and_relay_concurrently() { + const MASKING_SESSIONS: usize = 48; + const RELAY_SESSIONS: usize = 48; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let backend_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + for _ in 0..MASKING_SESSIONS { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req = [0u8; 32]; + stream.read_exact(&mut req).await.unwrap(); + assert!( + req.starts_with(b"GET /storm/"), + "masking backend must receive storm reconnect probes" + ); + stream.write_all(&backend_reply).await.unwrap(); + stream.shutdown().await.unwrap(); + } + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(config); + let beobachten = Arc::new(BeobachtenStore::new()); + let peer: SocketAddr = "198.51.100.200:55555".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut masking_tasks = Vec::with_capacity(MASKING_SESSIONS); + for i in 0..MASKING_SESSIONS { + let config = Arc::clone(&config); + let beobachten = Arc::clone(&beobachten); + let expected_reply = backend_reply.clone(); + masking_tasks.push(tokio::spawn(async move { + let mut probe = [0u8; 32]; + let template = format!("GET /storm/{i:04} HTTP/1.1\r\n\r\n"); + let bytes 
= template.as_bytes(); + probe[..bytes.len()].copy_from_slice(bytes); + + let (client_reader, client_writer) = duplex(256); + drop(client_writer); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + let mut observed = vec![0u8; expected_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, expected_reply); + + timeout(Duration::from_secs(2), handle) + .await + .expect("masking reconnect task must complete") + .unwrap(); + })); + } + + let mut relay_tasks = Vec::with_capacity(RELAY_SESSIONS); + for i in 0..RELAY_SESSIONS { + relay_tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "chaos-storm-relay", + stats, + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![(i as u8).wrapping_add(1); 64]; + client_peer.write_all(&c2s).await.unwrap(); + let mut c2s_seen = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut c2s_seen).await.unwrap(); + assert_eq!(c2s_seen, c2s); + + let s2c = vec![(i as u8).wrapping_add(17); 96]; + server_peer.write_all(&s2c).await.unwrap(); + let mut s2c_seen = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut s2c_seen).await.unwrap(); + assert_eq!(s2c_seen, s2c); + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay reconnect task must complete") + .unwrap() + .unwrap(); + })); + } + + for task in masking_tasks { + 
timeout(Duration::from_secs(3), task) + .await + .expect("masking storm join must complete") + .unwrap(); + } + + for task in relay_tasks { + timeout(Duration::from_secs(3), task) + .await + .expect("relay storm join must complete") + .unwrap(); + } + + timeout(Duration::from_secs(3), backend_task) + .await + .expect("masking backend accept loop must complete") + .unwrap(); +} + +fn read_env_usize_or_default(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(raw) => match raw.parse::() { + Ok(parsed) if parsed > 0 => parsed, + _ => default, + }, + Err(_) => default, + } +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn chaos_burst_reconnect_storm_for_masking_and_relay_multiwave_soak() { + let waves = read_env_usize_or_default("CHAOS_WAVES", 4); + let masking_per_wave = read_env_usize_or_default("CHAOS_MASKING_PER_WAVE", 160); + let relay_per_wave = read_env_usize_or_default("CHAOS_RELAY_PER_WAVE", 160); + let total_masking = waves * masking_per_wave; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let backend_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + for _ in 0..total_masking { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut req = [0u8; 32]; + stream.read_exact(&mut req).await.unwrap(); + assert!( + req.starts_with(b"GET /storm/"), + "mask backend must only receive storm probes" + ); + stream.write_all(&backend_reply).await.unwrap(); + stream.shutdown().await.unwrap(); + } + } + }); + + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(config); + let beobachten = 
Arc::new(BeobachtenStore::new()); + let peer: SocketAddr = "198.51.100.201:56565".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for wave in 0..waves { + let mut masking_tasks = Vec::with_capacity(masking_per_wave); + for i in 0..masking_per_wave { + let config = Arc::clone(&config); + let beobachten = Arc::clone(&beobachten); + let expected_reply = backend_reply.clone(); + masking_tasks.push(tokio::spawn(async move { + let mut probe = [0u8; 32]; + let template = format!("GET /storm/{wave:02}-{i:03}\r\n\r\n"); + let bytes = template.as_bytes(); + probe[..bytes.len()].copy_from_slice(bytes); + + let (client_reader, client_writer) = duplex(256); + drop(client_writer); + let (mut client_visible_reader, client_visible_writer) = duplex(1024); + + let handle = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + let mut observed = vec![0u8; expected_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, expected_reply); + + timeout(Duration::from_secs(3), handle) + .await + .expect("masking storm task must complete") + .unwrap(); + })); + } + + let mut relay_tasks = Vec::with_capacity(relay_per_wave); + for i in 0..relay_per_wave { + relay_tasks.push(tokio::spawn(async move { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "chaos-multiwave-relay", + stats, + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(1); 32]; + 
client_peer.write_all(&c2s).await.unwrap(); + let mut c2s_seen = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut c2s_seen).await.unwrap(); + assert_eq!(c2s_seen, c2s); + + let s2c = vec![(wave as u8).wrapping_add(i as u8).wrapping_add(17); 48]; + server_peer.write_all(&s2c).await.unwrap(); + let mut s2c_seen = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut s2c_seen).await.unwrap(); + assert_eq!(s2c_seen, s2c); + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(3), relay_task) + .await + .expect("relay storm task must complete") + .unwrap() + .unwrap(); + })); + } + + for task in masking_tasks { + timeout(Duration::from_secs(6), task) + .await + .expect("masking wave task join must complete") + .unwrap(); + } + + for task in relay_tasks { + timeout(Duration::from_secs(6), task) + .await + .expect("relay wave task join must complete") + .unwrap(); + } + } + + timeout(Duration::from_secs(8), backend_task) + .await + .expect("mask backend must complete all accepted storm sessions") + .unwrap(); +} + +#[tokio::test] +#[ignore = "heavy soak; run manually"] +async fn masking_timing_bucket_soak_refused_backend_stays_within_narrow_band() { + let mut config = ProxyConfig::default(); + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + + let peer: SocketAddr = "203.0.113.74:50006".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut samples = Vec::with_capacity(128); + for _ in 0..128 { + let (client_reader, _client_writer) = duplex(128); + let (_client_visible_reader, client_visible_writer) = duplex(128); + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET / HTTP/1.1\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + samples.push(started.elapsed().as_millis()); + } + + samples.sort_unstable(); + let p10 = 
samples[samples.len() / 10]; + let p90 = samples[(samples.len() * 9) / 10]; + assert!( + p90.saturating_sub(p10) <= 40, + "timing spread too wide for refused-backend masking path: p10={p10}ms p90={p90}ms" + ); +} diff --git a/src/proxy/tests/masking_security_tests.rs b/src/proxy/tests/masking_security_tests.rs new file mode 100644 index 0000000..9107ca9 --- /dev/null +++ b/src/proxy/tests/masking_security_tests.rs @@ -0,0 +1,1768 @@ +use super::*; +use crate::config::ProxyConfig; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; +use tokio::time::{Instant, sleep, timeout, Duration}; + +#[tokio::test] +async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, 
_client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn tls_scanner_probe_keeps_http_like_fallback_surface() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, 
+ ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[TLS-scanner]")); + assert!(snapshot.contains("198.51.100.44-1")); + accept_task.await.unwrap(); +} + +#[test] +fn detect_client_type_covers_ssh_port_scanner_and_unknown() { + assert_eq!(detect_client_type(b"SSH-2.0-OpenSSH_9.7"), "SSH"); + assert_eq!(detect_client_type(b"\x01\x02\x03"), "port-scanner"); + assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); +} + +#[test] +fn detect_client_type_len_boundary_9_vs_10_bytes() { + assert_eq!(detect_client_type(b"123456789"), "port-scanner"); + assert_eq!(detect_client_type(b"1234567890"), "unknown"); +} + +#[test] +fn build_mask_proxy_header_version_zero_disables_header() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let header = build_mask_proxy_header(0, peer, local_addr); + assert!(header.is_none(), "version 0 must disable PROXY header"); +} + +#[test] +fn build_mask_proxy_header_v2_matches_builder_output() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let expected = ProxyProtocolV2Builder::new() + .with_addrs(peer, local_addr) + .build(); + let actual = build_mask_proxy_header(2, peer, local_addr) + .expect("v2 mode must produce a header"); + + assert_eq!(actual, expected, "v2 header bytes must be deterministic"); +} + +#[test] +fn build_mask_proxy_header_v1_mixed_ip_family_uses_generic_unknown_form() { + let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap(); + let local_addr: SocketAddr = "[2001:db8::1]:443".parse().unwrap(); + + let expected = ProxyProtocolV1Builder::new().build(); + let actual = build_mask_proxy_header(1, peer, local_addr) + 
.expect("v1 mode must produce a header"); + + assert_eq!(actual, expected, "mixed-family v1 must use UNKNOWN form"); +} + +#[tokio::test] +async fn beobachten_records_scanner_class_when_mask_is_disabled() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.99:41234".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"SSH-2.0-probe"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + beobachten + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[SSH]")); + assert!(snapshot.contains("203.0.113.99-1")); +} + +#[tokio::test] +async fn backend_unavailable_falls_back_to_silent_consume() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (mut client_reader_side, 
client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn backend_connect_refusal_waits_mask_connect_budget_before_fallback() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.12:42426".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n"; + + // Close client reader immediately to force the refusal path to rely on masking budget timing. 
+ let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("masking fallback must not complete before connect budget elapses"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "fallback path must absorb immediate connect refusal into connect budget" + ); +} + +#[tokio::test] +async fn backend_reachable_fast_response_waits_mask_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /ok HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.13:42427".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let 
(_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() >= Duration::from_millis(45), + "reachable mask path must also satisfy coarse outcome budget" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_header_write_error_on_tcp_path_still_honors_coarse_outcome_budget() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /proxy-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.88:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("proxy-header write error path should remain inside coarse masking budget window"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "proxy-header write error path should avoid 
immediate-return timing signature" + ); + + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn proxy_header_write_error_on_unix_path_still_honors_coarse_outcome_budget() { + let sock_path = format!( + "/tmp/telemt-mask-unix-hdr-err-{}-{}.sock", + std::process::id(), + rand::random::() + ); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-hdr-err HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + drop(stream); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.89:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader_side, client_reader) = duplex(256); + drop(client_reader_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_millis(35), task) + .await + .expect_err("unix proxy-header write error path should remain inside coarse masking budget window"); + assert!( + started.elapsed() >= Duration::from_millis(35), + "unix proxy-header write error path should avoid immediate-return timing signature" + ); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_proxy_protocol_v1_header_is_sent_before_probe() { + let sock_path = format!( + "/tmp/telemt-mask-unix-v1-{}-{}.sock", + std::process::id(), + rand::random::() + ); + 
let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix-v1 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert!(header_text.starts_with("PROXY "), "must start with PROXY prefix"); + assert!(header_text.ends_with("\r\n"), "v1 header must end with CRLF"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.51:51010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[cfg(unix)] +#[tokio::test] +async fn 
unix_socket_proxy_protocol_v2_header_is_sent_before_probe() {
    // Unique per-process/per-run socket path; restored the <u64> turbofish
    // (previous revision had the garbled, non-compiling `rand::random::()`).
    let sock_path = format!(
        "/tmp/telemt-mask-unix-v2-{}-{}.sock",
        std::process::id(),
        rand::random::<u64>()
    );
    let _ = std::fs::remove_file(&sock_path);

    let listener = UnixListener::bind(&sock_path).unwrap();
    let probe = b"GET /unix-v2 HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
    let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec();

    // Backend validates wire order: v2 signature, fixed header, address block,
    // and only then the verbatim probe bytes.
    let accept_task = tokio::spawn({
        let probe = probe.clone();
        let backend_reply = backend_reply.clone();
        async move {
            let (mut stream, _) = listener.accept().await.unwrap();

            let mut sig = [0u8; 12];
            stream.read_exact(&mut sig).await.unwrap();
            assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n", "v2 signature must match spec");

            // Bytes 2..4 of the fixed part carry the big-endian address block length.
            let mut fixed = [0u8; 4];
            stream.read_exact(&mut fixed).await.unwrap();
            let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize;
            let mut addr_block = vec![0u8; addr_len];
            stream.read_exact(&mut addr_block).await.unwrap();

            let mut received_probe = vec![0u8; probe.len()];
            stream.read_exact(&mut received_probe).await.unwrap();
            assert_eq!(received_probe, probe);

            stream.write_all(&backend_reply).await.unwrap();
        }
    });

    let mut config = ProxyConfig::default();
    config.general.beobachten = false;
    config.censorship.mask = true;
    config.censorship.mask_unix_sock = Some(sock_path.clone());
    config.censorship.mask_proxy_protocol = 2;

    let peer: SocketAddr = "203.0.113.52:51011".parse().unwrap();
    let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();

    let (client_reader, _client_writer) = duplex(256);
    let (mut client_visible_reader, client_visible_writer) = duplex(2048);

    let beobachten = BeobachtenStore::new();
    handle_bad_client(
        client_reader,
        client_visible_writer,
        &probe,
        peer,
        local_addr,
        &config,
        &beobachten,
    )
    .await;

    // The backend's reply must be relayed back to the scanner unchanged.
    let mut observed = vec![0u8; backend_reply.len()];
    client_visible_reader.read_exact(&mut
observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[tokio::test] +async fn mask_disabled_fast_eof_not_shaped_by_mask_budget() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.14:42428".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"x", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + assert!( + started.elapsed() < Duration::from_millis(20), + "mask-disabled fallback should keep immediate EOF behavior" + ); +} + +#[tokio::test] +async fn backend_reachable_slow_response_not_padded_twice() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /slow HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(90)).await; + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + 
config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.15:42429".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + assert!(elapsed >= Duration::from_millis(85)); + assert!( + elapsed < Duration::from_millis(170), + "slow reachable backend should not incur an extra full budget after already exceeding it" + ); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_enabled_refused_and_reachable_collapse_to_same_bucket() { + const ITER: usize = 20; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /collapse HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.16:42430".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let mut refused = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + 
client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused.push(started.elapsed().as_millis()); + } + + let mut reachable = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe_vec = probe.to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + let refused_mean = refused.iter().copied().sum::() as f64 / refused.len() as f64; + let reachable_mean = reachable.iter().copied().sum::() as f64 / reachable.len() as f64; + let refused_bucket = (refused_mean as u128) / BUCKET_MS; + let reachable_bucket = (reachable_mean as u128) / BUCKET_MS; + + assert!( + refused_bucket.abs_diff(reachable_bucket) <= 1, + "enabled refused and reachable paths must collapse into the same coarse latency bucket" + ); +} + +#[tokio::test] +async fn 
light_fuzz_mask_enabled_outcomes_preserve_coarse_budget() { + let mut seed: u64 = 0xA5A5_5A5A_1337_4242; + let mut next = || { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + seed + }; + + let peer: SocketAddr = "203.0.113.17:42431".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + for _ in 0..40 { + let probe_len = (next() as usize % 96).saturating_add(8); + let mut probe = vec![0u8; probe_len]; + for byte in &mut probe { + *byte = next() as u8; + } + + let use_reachable = (next() & 1) == 0; + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(512); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + if use_reachable { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + + let probe_vec = probe.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut observed).await.unwrap(); + }); + + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + accept_task.await.unwrap(); + } else { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + + handle_bad_client( + client_reader, + client_visible_writer, 
+ &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + } + + assert!( + started.elapsed() >= Duration::from_millis(45), + "mask-enabled fallback must preserve coarse timing budget under varied probe shapes" + ); + } +} + +#[tokio::test] +async fn mask_disabled_consumes_client_data_without_response() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"scanner"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_reader_side.write_all(b"untrusted payload").await.unwrap(); + drop(client_reader_side); + + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + + let mut buf = [0u8; 1]; + let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf)) + .await + .unwrap() + .unwrap(); + assert_eq!(n, 0); +} + +#[tokio::test] +async fn proxy_protocol_v1_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = 
String::from_utf8(header_line.clone()).unwrap(); + assert!(header_text.starts_with("PROXY TCP4 ")); + assert!(header_text.ends_with("\r\n")); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v2_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut sig).await.unwrap(); + assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n"); + 
+ let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.18:50004".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /mix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = 
BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert_eq!(header_text, "PROXY UNKNOWN\r\n"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.20:50006".parse().unwrap(); + let local_addr: SocketAddr = "[::1]:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_mask_path_forwards_probe_and_response() { + let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::()); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); 
+ async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.30:50010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +} + +#[tokio::test] +async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"slowloris", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn 
mask_enabled_idle_relay_is_closed_by_idle_timeout_before_global_relay_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /idle HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "198.51.100.34:45456".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(512); + let (_client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed < Duration::from_millis(150), + "idle unauth relay must terminate on idle timeout instead of waiting for full relay timeout" + ); + + accept_task.await.unwrap(); +} + +struct PendingWriter; + +impl tokio::io::AsyncWrite for PendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + 
Poll::Ready(Ok(())) + } +} + +struct DropTrackedPendingReader { + dropped: Arc, +} + +impl tokio::io::AsyncRead for DropTrackedPendingReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +impl Drop for DropTrackedPendingReader { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +struct DropTrackedPendingWriter { + dropped: Arc, +} + +impl tokio::io::AsyncWrite for DropTrackedPendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl Drop for DropTrackedPendingWriter { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +#[tokio::test] +async fn proxy_header_write_timeout_returns_false() { + let mut writer = PendingWriter; + let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await; + assert!(!ok, "Proxy header writes that never complete must time out"); +} + +#[tokio::test] +async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() { + let (mut client_feed_writer, client_feed_reader) = duplex(64); + let (mut client_visible_reader, client_visible_writer) = duplex(64); + let (mut backend_feed_writer, backend_feed_reader) = duplex(64); + + // Make client->mask direction immediately active so the c2m path blocks on PendingWriter. + client_feed_writer.write_all(b"X").await.unwrap(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_feed_reader, + client_visible_writer, + backend_feed_reader, + PendingWriter, + b"", + false, + 0, + 0, + false, + 0, + ) + .await; + }); + + // Allow relay tasks to start, then emulate mask backend response. 
+ sleep(Duration::from_millis(20)).await; + backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); + backend_feed_writer.shutdown().await.unwrap(); + + let mut observed = vec![0u8; 19]; + timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n"); + + relay.abort(); + let _ = relay.await; +} + +#[tokio::test] +async fn relay_to_mask_preserves_backend_response_after_client_half_close() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let backend_task = tokio::spawn({ + let request = request.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed_req = vec![0u8; request.len()]; + stream.read_exact(&mut observed_req).await.unwrap(); + assert_eq!(observed_req, request); + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_write, client_read) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + let beobachten = BeobachtenStore::new(); + + let fallback_task = tokio::spawn(async move { + handle_bad_client( + client_read, + client_visible_writer, + &request, + peer, + local_addr, + &config, + &beobachten, + ) + 
.await; + }); + + client_write.shutdown().await.unwrap(); + + let mut observed_resp = vec![0u8; response.len()]; + timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed_resp, response); + + timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { + let reader_dropped = Arc::new(AtomicBool::new(false)); + let writer_dropped = Arc::new(AtomicBool::new(false)); + let mask_reader_dropped = Arc::new(AtomicBool::new(false)); + let mask_writer_dropped = Arc::new(AtomicBool::new(false)); + + let reader = DropTrackedPendingReader { + dropped: reader_dropped.clone(), + }; + let writer = DropTrackedPendingWriter { + dropped: writer_dropped.clone(), + }; + let mask_read = DropTrackedPendingReader { + dropped: mask_reader_dropped.clone(), + }; + let mask_write = DropTrackedPendingWriter { + dropped: mask_writer_dropped.clone(), + }; + + let timed = timeout( + Duration::from_millis(40), + relay_to_mask( + reader, + writer, + mask_read, + mask_write, + b"", + false, + 0, + 0, + false, + 0, + ), + ) + .await; + + assert!(timed.is_err(), "stalled relay must be bounded by timeout"); + + assert!(reader_dropped.load(Ordering::SeqCst)); + assert!(writer_dropped.load(Ordering::SeqCst)); + assert!(mask_reader_dropped.load(Ordering::SeqCst)); + assert!(mask_writer_dropped.load(Ordering::SeqCst)); +} + +#[tokio::test] +#[ignore = "timing matrix; run manually with --ignored --nocapture"] +async fn timing_matrix_masking_classes_under_controlled_inputs() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let probe = b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n"; + let peer: SocketAddr = "203.0.113.40:51000".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + // Class 1: masking disabled 
with immediate EOF (fast fail-closed consume path). + let mut disabled_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + disabled_samples.push(started.elapsed().as_millis()); + } + + // Class 2: masking enabled, backend connect refused. + let mut refused_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + refused_samples.push(started.elapsed().as_millis()); + } + + // Class 3: masking enabled, backend reachable and immediately responds. 
+ let mut reachable_samples = Vec::with_capacity(ITER); + for _ in 0..ITER { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + let probe_vec = probe.to_vec(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe_vec.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe_vec); + stream.write_all(&backend_reply).await.unwrap(); + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let (client_writer_side, client_reader) = duplex(256); + drop(client_writer_side); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + reachable_samples.push(started.elapsed().as_millis()); + accept_task.await.unwrap(); + } + + fn summarize(samples_ms: &mut [u128]) -> (f64, u128, u128, u128) { + samples_ms.sort_unstable(); + let sum: u128 = samples_ms.iter().copied().sum(); + let mean = sum as f64 / samples_ms.len() as f64; + let min = samples_ms[0]; + let p95_idx = ((samples_ms.len() as f64) * 0.95).floor() as usize; + let p95 = samples_ms[p95_idx.min(samples_ms.len() - 1)]; + let max = samples_ms[samples_ms.len() - 1]; + (mean, min, p95, max) + } + + let (disabled_mean, disabled_min, disabled_p95, disabled_max) = summarize(&mut disabled_samples); + let (refused_mean, refused_min, refused_p95, refused_max) = 
summarize(&mut refused_samples); + let (reachable_mean, reachable_min, reachable_p95, reachable_max) = summarize(&mut reachable_samples); + + println!( + "TIMING_MATRIX masking class=disabled_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + disabled_mean, + disabled_min, + disabled_p95, + disabled_max, + (disabled_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_refused_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + refused_mean, + refused_min, + refused_p95, + refused_max, + (refused_mean as u128) / BUCKET_MS + ); + println!( + "TIMING_MATRIX masking class=enabled_reachable_eof mean_ms={:.2} min_ms={} p95_ms={} max_ms={} bucket_mean={}", + reachable_mean, + reachable_min, + reachable_p95, + reachable_max, + (reachable_mean as u128) / BUCKET_MS + ); +} + +#[tokio::test] +async fn backend_connect_refusal_completes_within_bounded_mask_budget() { + let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let unused_port = temp_listener.local_addr().unwrap().port(); + drop(temp_listener); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = unused_port; + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.41:51001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let probe = b"GET /bounded HTTP/1.1\r\nHost: x\r\n\r\n"; + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let elapsed = started.elapsed(); + assert!( + elapsed >= 
Duration::from_millis(45), + "connect refusal path must respect minimum masking budget" + ); + assert!( + elapsed < Duration::from_millis(500), + "connect refusal path must stay bounded and avoid unbounded stall" + ); +} + +#[tokio::test] +async fn reachable_backend_one_response_then_silence_is_cut_by_idle_timeout() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /oneshot HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&response).await.unwrap(); + sleep(Duration::from_millis(300)).await; + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.42:51002".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(512); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + let elapsed = started.elapsed(); + + let mut observed = vec![0u8; response.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, response); + assert!( + elapsed < 
Duration::from_millis(190), + "idle backend silence after first response must be cut by relay idle timeout" + ); + + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn adversarial_client_drip_feed_longer_than_idle_timeout_is_cut_off() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut extra = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut extra)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "drip-fed post-probe byte arriving after idle timeout should not be forwarded" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.43:51003".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = client_writer_side.write_all(b"X").await; + drop(client_writer_side); + + timeout(Duration::from_secs(1), 
relay_task).await.unwrap().unwrap(); + accept_task.await.unwrap(); +} diff --git a/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs new file mode 100644 index 0000000..d2d522f --- /dev/null +++ b/src/proxy/tests/masking_shape_above_cap_blur_security_tests.rs @@ -0,0 +1,102 @@ +use super::*; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +async fn capture_forwarded_len( + body_sent: usize, + shape_hardening: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = shape_hardening; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (server_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + let peer: SocketAddr = "198.51.100.220:57120".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let mut probe = vec![0u8; 5 + body_sent]; + probe[0] = 0x16; + probe[1] = 0x03; + probe[2] = 0x01; + 
probe[3..5].copy_from_slice(&7000u16.to_be_bytes()); + probe[5..].fill(0x5A); + + let fallback = tokio::spawn(async move { + handle_bad_client( + server_reader, + client_visible_writer, + &probe, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(4), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap() +} + +#[tokio::test] +async fn above_cap_blur_disabled_keeps_exact_above_cap_length() { + let body_sent = 5000usize; + let observed = capture_forwarded_len(body_sent, true, false, 0).await; + assert_eq!(observed, 5 + body_sent); +} + +#[tokio::test] +async fn above_cap_blur_enabled_adds_bounded_random_tail() { + let body_sent = 5000usize; + let base = 5 + body_sent; + let max_extra = 64usize; + + let mut saw_extra = false; + for _ in 0..20 { + let observed = capture_forwarded_len(body_sent, true, true, max_extra).await; + assert!(observed >= base, "observed={observed} base={base}"); + assert!( + observed <= base + max_extra, + "observed={observed} base={} max_extra={max_extra}", + base + ); + if observed > base { + saw_extra = true; + } + } + + assert!( + saw_extra, + "at least one run should produce above-cap blur bytes under randomization" + ); +} diff --git a/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs new file mode 100644 index 0000000..9e8c5b7 --- /dev/null +++ b/src/proxy/tests/masking_shape_classifier_resistance_adversarial_tests.rs @@ -0,0 +1,324 @@ +use super::*; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::Duration; + +async fn capture_forwarded_len( + body_sent: usize, + shape_hardening: bool, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> usize { + let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = shape_hardening; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + config.censorship.mask_shape_above_cap_blur = above_cap_blur; + config.censorship.mask_shape_above_cap_blur_max_bytes = above_cap_blur_max_bytes; + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + let _ = tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut got)).await; + got.len() + }); + + let (client_reader, mut client_writer) = duplex(64 * 1024); + let (_client_visible_reader, client_visible_writer) = duplex(64 * 1024); + + let mut initial = vec![0u8; 5 + body_sent]; + initial[0] = 0x16; + initial[1] = 0x03; + initial[2] = 0x01; + initial[3..5].copy_from_slice(&7000u16.to_be_bytes()); + initial[5..].fill(0x5A); + + let peer: SocketAddr = "198.51.100.250:57450".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let fallback = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.shutdown().await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(3), fallback) + .await + .unwrap() + .unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap() +} + +fn best_threshold_accuracy(a: &[usize], b: &[usize]) -> f64 { + let min_v = *a.iter().chain(b.iter()).min().unwrap(); + let max_v = 
*a.iter().chain(b.iter()).max().unwrap(); + + let mut best = 0.0f64; + for t in min_v..=max_v { + let correct_a = a.iter().filter(|&&x| x <= t).count(); + let correct_b = b.iter().filter(|&&x| x > t).count(); + let acc = (correct_a + correct_b) as f64 / (a.len() + b.len()) as f64; + if acc > best { + best = acc; + } + } + best +} + +fn nearest_centroid_classifier_accuracy( + samples_a: &[usize], + samples_b: &[usize], + samples_c: &[usize], +) -> f64 { + let mean = |xs: &[usize]| -> f64 { + xs.iter().copied().sum::() as f64 / xs.len() as f64 + }; + + let ca = mean(samples_a); + let cb = mean(samples_b); + let cc = mean(samples_c); + + let mut correct = 0usize; + let mut total = 0usize; + + for &x in samples_a { + total += 1; + let xf = x as f64; + let d = [ + (xf - ca).abs(), + (xf - cb).abs(), + (xf - cc).abs(), + ]; + if d[0] <= d[1] && d[0] <= d[2] { + correct += 1; + } + } + + for &x in samples_b { + total += 1; + let xf = x as f64; + let d = [ + (xf - ca).abs(), + (xf - cb).abs(), + (xf - cc).abs(), + ]; + if d[1] <= d[0] && d[1] <= d[2] { + correct += 1; + } + } + + for &x in samples_c { + total += 1; + let xf = x as f64; + let d = [ + (xf - ca).abs(), + (xf - cb).abs(), + (xf - cc).abs(), + ]; + if d[2] <= d[0] && d[2] <= d[1] { + correct += 1; + } + } + + correct as f64 / total as f64 +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_blur_reduces_threshold_attack_accuracy() { + const SAMPLES: usize = 120; + const MAX_EXTRA: usize = 96; + const CLASS_A_BODY: usize = 5000; + const CLASS_B_BODY: usize = 5040; + + let mut baseline_a = Vec::with_capacity(SAMPLES); + let mut baseline_b = Vec::with_capacity(SAMPLES); + let mut hardened_a = Vec::with_capacity(SAMPLES); + let mut hardened_b = Vec::with_capacity(SAMPLES); + + for _ in 0..SAMPLES { + baseline_a.push(capture_forwarded_len(CLASS_A_BODY, true, false, 0).await); + baseline_b.push(capture_forwarded_len(CLASS_B_BODY, true, false, 0).await); + 
hardened_a.push(capture_forwarded_len(CLASS_A_BODY, true, true, MAX_EXTRA).await); + hardened_b.push(capture_forwarded_len(CLASS_B_BODY, true, true, MAX_EXTRA).await); + } + + let baseline_acc = best_threshold_accuracy(&baseline_a, &baseline_b); + let hardened_acc = best_threshold_accuracy(&hardened_a, &hardened_b); + + // Baseline classes are deterministic/non-overlapping -> near-perfect threshold attack. + assert!(baseline_acc >= 0.99, "baseline separability unexpectedly low: {baseline_acc:.3}"); + // Blur must materially reduce the best one-dimensional length classifier. + assert!( + hardened_acc <= 0.90, + "blur should degrade threshold attack accuracy, got {hardened_acc:.3}" + ); + assert!( + hardened_acc <= baseline_acc - 0.08, + "blur must reduce threshold accuracy by a meaningful margin: baseline={baseline_acc:.3}, hardened={hardened_acc:.3}" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_blur_increases_cross_class_overlap() { + const SAMPLES: usize = 96; + const MAX_EXTRA: usize = 96; + const CLASS_A_BODY: usize = 5000; + const CLASS_B_BODY: usize = 5040; + + let mut baseline_a = std::collections::BTreeSet::new(); + let mut baseline_b = std::collections::BTreeSet::new(); + let mut hardened_a = std::collections::BTreeSet::new(); + let mut hardened_b = std::collections::BTreeSet::new(); + + for _ in 0..SAMPLES { + baseline_a.insert(capture_forwarded_len(CLASS_A_BODY, true, false, 0).await); + baseline_b.insert(capture_forwarded_len(CLASS_B_BODY, true, false, 0).await); + hardened_a.insert(capture_forwarded_len(CLASS_A_BODY, true, true, MAX_EXTRA).await); + hardened_b.insert(capture_forwarded_len(CLASS_B_BODY, true, true, MAX_EXTRA).await); + } + + let baseline_overlap = baseline_a.intersection(&baseline_b).count(); + let hardened_overlap = hardened_a.intersection(&hardened_b).count(); + + assert_eq!(baseline_overlap, 0, "baseline classes should not overlap"); + assert!( + hardened_overlap >= 8, + "blur should create meaningful 
overlap between classes, got overlap={hardened_overlap}" + ); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_parallel_probe_campaign_keeps_blur_bounds() { + const MAX_EXTRA: usize = 128; + + let mut tasks = Vec::new(); + for i in 0..64usize { + tasks.push(tokio::spawn(async move { + let body = 4300 + (i % 700); + let observed = capture_forwarded_len(body, true, true, MAX_EXTRA).await; + let base = 5 + body; + assert!( + observed >= base && observed <= base + MAX_EXTRA, + "campaign bounds violated for i={i}: observed={observed} base={base}" + ); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(3), task) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_edge_max_extra_one_has_two_point_support() { + const BODY: usize = 5000; + const BASE: usize = 5 + BODY; + + let mut seen = std::collections::BTreeSet::new(); + for _ in 0..64 { + let observed = capture_forwarded_len(BODY, true, true, 1).await; + assert!( + observed == BASE || observed == BASE + 1, + "max_extra=1 must only produce two-point support" + ); + seen.insert(observed); + } + + assert_eq!(seen.len(), 2, "both support points should appear under repeated sampling"); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_negative_blur_without_shape_hardening_is_noop() { + const BODY_A: usize = 5000; + const BODY_B: usize = 5040; + + let mut as_observed = std::collections::BTreeSet::new(); + let mut bs_observed = std::collections::BTreeSet::new(); + for _ in 0..48 { + as_observed.insert(capture_forwarded_len(BODY_A, false, true, 96).await); + bs_observed.insert(capture_forwarded_len(BODY_B, false, true, 96).await); + } + + assert_eq!(as_observed.len(), 1, "without shape hardening class A must stay deterministic"); + assert_eq!(bs_observed.len(), 1, "without shape hardening class B must stay deterministic"); + assert_ne!(as_observed, bs_observed, "distinct classes should remain separable without 
shaping"); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_adversarial_three_class_centroid_attack_degrades_with_blur() { + const SAMPLES: usize = 80; + const MAX_EXTRA: usize = 96; + const C1: usize = 5000; + const C2: usize = 5040; + const C3: usize = 5080; + + let mut base1 = Vec::with_capacity(SAMPLES); + let mut base2 = Vec::with_capacity(SAMPLES); + let mut base3 = Vec::with_capacity(SAMPLES); + let mut hard1 = Vec::with_capacity(SAMPLES); + let mut hard2 = Vec::with_capacity(SAMPLES); + let mut hard3 = Vec::with_capacity(SAMPLES); + + for _ in 0..SAMPLES { + base1.push(capture_forwarded_len(C1, true, false, 0).await); + base2.push(capture_forwarded_len(C2, true, false, 0).await); + base3.push(capture_forwarded_len(C3, true, false, 0).await); + + hard1.push(capture_forwarded_len(C1, true, true, MAX_EXTRA).await); + hard2.push(capture_forwarded_len(C2, true, true, MAX_EXTRA).await); + hard3.push(capture_forwarded_len(C3, true, true, MAX_EXTRA).await); + } + + let base_acc = nearest_centroid_classifier_accuracy(&base1, &base2, &base3); + let hard_acc = nearest_centroid_classifier_accuracy(&hard1, &hard2, &hard3); + + assert!(base_acc >= 0.99, "baseline centroid separability should be near-perfect"); + assert!(hard_acc <= 0.88, "blur should materially degrade 3-class centroid attack"); + assert!(hard_acc <= base_acc - 0.1, "accuracy drop should be meaningful"); +} + +#[tokio::test] +async fn masking_shape_classifier_resistance_light_fuzz_bounds_hold_for_randomized_above_cap_campaign() { + let mut s: u64 = 0xDEAD_BEEF_CAFE_BABE; + for _ in 0..96 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let body = 4097 + (s as usize % 2048); + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let max_extra = 1 + (s as usize % 128); + + let observed = capture_forwarded_len(body, true, true, max_extra).await; + let base = 5 + body; + assert!( + observed >= base && observed <= base + max_extra, + "fuzz bounds violated: body={body} observed={observed} 
max_extra={max_extra}" + ); + } +} diff --git a/src/proxy/tests/masking_shape_guard_adversarial_tests.rs b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs new file mode 100644 index 0000000..fc0b0b8 --- /dev/null +++ b/src/proxy/tests/masking_shape_guard_adversarial_tests.rs @@ -0,0 +1,371 @@ +use super::*; +use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{sleep, timeout, Duration}; + +fn oracle_len( + total_sent: usize, + shape_enabled: bool, + ended_by_eof: bool, + initial_len: usize, + floor: usize, + cap: usize, +) -> usize { + if shape_enabled && ended_by_eof && initial_len > 0 { + next_mask_shape_bucket(total_sent, floor, cap) + } else { + total_sent + } +} + +async fn run_relay_case( + initial: Vec, + extra: Vec, + close_client: bool, + shape_enabled: bool, + floor: usize, + cap: usize, + above_cap_blur: bool, + above_cap_blur_max_bytes: usize, +) -> Vec { + let (client_reader, mut client_writer) = duplex(8192); + let (mut mask_observer, mask_writer) = duplex(8192); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + empty(), + mask_writer, + &initial, + shape_enabled, + floor, + cap, + above_cap_blur, + above_cap_blur_max_bytes, + ) + .await; + }); + + if !extra.is_empty() { + client_writer.write_all(&extra).await.unwrap(); + } + + if close_client { + client_writer.shutdown().await.unwrap(); + } + + timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + + if !close_client { + drop(client_writer); + } + + let mut observed = Vec::new(); + timeout(Duration::from_secs(2), mask_observer.read_to_end(&mut observed)) + .await + .unwrap() + .unwrap(); + observed +} + +#[tokio::test] +async fn masking_shape_guard_negative_timeout_path_never_shapes_even_with_blur_enabled() { + let initial = b"GET /timeout-path HTTP/1.1\r\n".to_vec(); + let extra = vec![0xCC; 700]; + let total = initial.len() + extra.len(); + + let observed = run_relay_case( + initial.clone(), + extra.clone(), 
+ false, + true, + 512, + 4096, + true, + 1024, + ) + .await; + + assert_eq!(observed.len(), total, "timeout path must stay unshaped"); + assert_eq!(&observed[..initial.len()], initial.as_slice()); + assert_eq!(&observed[initial.len()..], extra.as_slice()); +} + +#[tokio::test] +async fn masking_shape_guard_positive_clean_eof_path_shapes_and_preserves_prefix() { + let initial = b"GET /ok HTTP/1.1\r\n".to_vec(); + let extra = vec![0x55; 300]; + let total = initial.len() + extra.len(); + + let observed = run_relay_case(initial.clone(), extra.clone(), true, true, 512, 4096, false, 0).await; + + let expected_len = oracle_len(total, true, true, initial.len(), 512, 4096); + assert_eq!(observed.len(), expected_len, "clean EOF path must be bucket-shaped"); + assert_eq!(&observed[..initial.len()], initial.as_slice()); + assert_eq!(&observed[initial.len()..(initial.len() + extra.len())], extra.as_slice()); +} + +#[tokio::test] +async fn masking_shape_guard_edge_empty_initial_remains_transparent_under_clean_eof() { + let initial = Vec::new(); + let extra = vec![0xA1; 257]; + + let observed = run_relay_case(initial, extra.clone(), true, true, 512, 4096, false, 0).await; + + assert_eq!(observed.len(), extra.len(), "empty initial_data must never trigger shaping"); + assert_eq!(observed, extra); +} + +#[tokio::test] +async fn masking_shape_guard_light_fuzz_oracle_matches_for_eof_and_timeout_variants() { + let floor = 512usize; + let cap = 4096usize; + + // Deterministic xorshift to keep this fuzz test stable in CI. 
+ let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + for _ in 0..96 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let initial_len = (s as usize) % 48; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let extra_len = (s as usize) % 1800; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let close_client = (s & 1) == 0; + + let initial = vec![0x42; initial_len]; + let extra = vec![0x99; extra_len]; + let total = initial_len + extra_len; + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + close_client, + true, + floor, + cap, + false, + 0, + ) + .await; + + let expected = oracle_len(total, true, close_client, initial_len, floor, cap); + assert_eq!( + observed.len(), + expected, + "oracle mismatch: initial_len={initial_len} extra_len={extra_len} close_client={close_client}" + ); + + if initial_len > 0 { + assert_eq!(&observed[..initial_len], initial.as_slice()); + } + if extra_len > 0 { + assert_eq!( + &observed[initial_len..(initial_len + extra_len)], + extra.as_slice(), + "payload prefix must remain byte-for-byte before any optional shaping tail" + ); + } + } +} + +#[tokio::test] +async fn masking_shape_guard_stress_parallel_mixed_sessions_keep_oracle_and_no_hangs() { + let mut tasks = Vec::new(); + + for i in 0..48usize { + tasks.push(tokio::spawn(async move { + let initial_len = if i % 3 == 0 { 0 } else { 5 + (i % 19) }; + let extra_len = 64 + (i * 37 % 1300); + let close_client = i % 2 == 0; + + let initial = vec![i as u8; initial_len]; + let extra = vec![0xE0 | ((i as u8) & 0x0F); extra_len]; + let total = initial_len + extra_len; + + let observed = run_relay_case( + initial.clone(), + extra.clone(), + close_client, + true, + 512, + 4096, + false, + 0, + ) + .await; + + let expected = oracle_len(total, true, close_client, initial_len, 512, 4096); + assert_eq!( + observed.len(), + expected, + "stress oracle mismatch for worker={i} close_client={close_client}" + ); + + if initial_len > 0 { + assert_eq!(&observed[..initial_len], initial.as_slice()); + 
} + if extra_len > 0 { + assert_eq!(&observed[initial_len..(initial_len + extra_len)], extra.as_slice()); + } + })); + } + + for task in tasks { + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + } +} + +#[tokio::test] +async fn masking_shape_guard_integration_slow_drip_timeout_is_cut_without_tail_leak() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /drip-guard HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut one = [0u8; 1]; + let r = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; + assert!(r.is_err() || r.unwrap().is_err(), "no post-timeout drip/tail may reach backend"); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "198.51.100.245:53101".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_writer, client_reader) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + let beobachten = BeobachtenStore::new(); + + let relay = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + sleep(Duration::from_millis(160)).await; + let _ = client_writer.write_all(b"X").await; + + 
timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn masking_shape_guard_above_cap_blur_statistical_quality_and_bounds() { + let base_len = 5005usize; // 5-byte header + 5000 payload + let max_extra = 64usize; + let mut extras = Vec::new(); + + for _ in 0..192 { + let observed = run_relay_case( + vec![0x16, 0x03, 0x01, 0x1B, 0x58], + vec![0xAA; 5000], + true, + true, + 512, + 4096, + true, + max_extra, + ) + .await; + + assert!( + observed.len() >= base_len && observed.len() <= base_len + max_extra, + "above-cap blur length must stay in bounded window" + ); + extras.push(observed.len() - base_len); + } + + let unique: std::collections::BTreeSet<_> = extras.iter().copied().collect(); + let mean = extras.iter().copied().sum::() as f64 / extras.len() as f64; + + // For uniform [0..=64], mean is ~32. Keep wide bounds to avoid CI flakiness. + assert!( + (20.0..=44.0).contains(&mean), + "blur mean drifted too far from expected center, mean={mean:.2}" + ); + assert!( + unique.len() >= 16, + "blur distribution appears too low-entropy, unique_extras={}", + unique.len() + ); +} + +#[tokio::test] +async fn masking_shape_guard_above_cap_blur_parallel_stress_keeps_bounds() { + let max_extra = 96usize; + let mut tasks = Vec::new(); + + for i in 0..64usize { + tasks.push(tokio::spawn(async move { + let body_len = 4500 + (i % 256); + let base_len = 5 + body_len; + + let observed = run_relay_case( + vec![0x16, 0x03, 0x01, 0x1B, 0x58], + vec![0xA0 | ((i as u8) & 0x0F); body_len], + true, + true, + 512, + 4096, + true, + max_extra, + ) + .await; + + assert!( + observed.len() >= base_len && observed.len() <= base_len + max_extra, + "parallel blur bounds violated for worker={i}: observed_len={} base_len={} max_extra={}", + observed.len(), + base_len, + max_extra + ); + })); + } + + for task in tasks { + timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + } +} + 
+#[tokio::test] +async fn masking_shape_guard_above_cap_blur_disabled_keeps_exact_length_even_on_clean_eof() { + let initial = vec![0x16, 0x03, 0x01, 0x1B, 0x58]; + let body = vec![0x77; 5200]; + let expected = initial.len() + body.len(); + + let observed = run_relay_case(initial, body, true, true, 512, 4096, false, 0).await; + assert_eq!( + observed.len(), + expected, + "without above-cap blur the output must remain exact even on clean EOF" + ); +} diff --git a/src/proxy/tests/masking_shape_guard_security_tests.rs b/src/proxy/tests/masking_shape_guard_security_tests.rs new file mode 100644 index 0000000..72c208f --- /dev/null +++ b/src/proxy/tests/masking_shape_guard_security_tests.rs @@ -0,0 +1,167 @@ +use super::*; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn shape_guard_empty_initial_data_keeps_transparent_length_on_clean_eof() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let client_payload = vec![0x7A; 64]; + + let accept_task = tokio::spawn({ + let expected = client_payload.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + assert_eq!(got, expected, "empty initial_data path must not inject shape padding"); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.90:52001".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + 
let (mut client_writer, client_reader) = duplex(2048); + let (_client_visible_reader, client_visible_writer) = duplex(2048); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"", + peer, + local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.write_all(&client_payload).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn shape_guard_timeout_exit_does_not_append_padding_after_initial_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /timeout-shape-guard HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed = vec![0u8; initial.len()]; + stream.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, initial); + + let mut one = [0u8; 1]; + let read_res = timeout(Duration::from_millis(220), stream.read_exact(&mut one)).await; + assert!( + read_res.is_err() || read_res.unwrap().is_err(), + "idle-timeout path must not append shape padding after initial probe" + ); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.91:52002".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let 
(_client_reader_side, client_reader) = duplex(2048); + let (_client_visible_reader, client_visible_writer) = duplex(2048); + + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, + local, + &config, + &beobachten, + ) + .await; + + timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn shape_guard_clean_eof_with_nonempty_initial_still_applies_bucket_padding() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let initial = b"GET /shape-bucket HTTP/1.1\r\n".to_vec(); + let extra = vec![0x41; 31]; + + let accept_task = tokio::spawn({ + let initial = initial.clone(); + let extra = extra.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = Vec::new(); + stream.read_to_end(&mut got).await.unwrap(); + + let expected_prefix_len = initial.len() + extra.len(); + assert_eq!(&got[..initial.len()], initial.as_slice()); + assert_eq!(&got[initial.len()..expected_prefix_len], extra.as_slice()); + assert_eq!(got.len(), 512, "clean EOF path should still shape to floor bucket"); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_shape_hardening = true; + config.censorship.mask_shape_bucket_floor_bytes = 512; + config.censorship.mask_shape_bucket_cap_bytes = 4096; + + let peer: SocketAddr = "203.0.113.92:52003".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let (mut client_writer, client_reader) = duplex(4096); + let (_client_visible_reader, client_visible_writer) = duplex(4096); + + let relay_task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + &initial, + peer, 
+ local, + &config, + &beobachten, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + timeout(Duration::from_secs(2), relay_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), accept_task).await.unwrap().unwrap(); +} diff --git a/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs new file mode 100644 index 0000000..eade371 --- /dev/null +++ b/src/proxy/tests/masking_shape_hardening_adversarial_tests.rs @@ -0,0 +1,129 @@ +use super::*; +use tokio::io::{duplex, empty, sink, AsyncReadExt, AsyncWrite}; + +struct CountingWriter { + written: usize, +} + +impl CountingWriter { + fn new() -> Self { + Self { written: 0 } + } +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.written = self.written.saturating_add(buf.len()); + std::task::Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + +#[test] +fn shape_bucket_clamps_to_cap_when_next_power_of_two_exceeds_cap() { + let bucket = next_mask_shape_bucket(1200, 1000, 1500); + assert_eq!(bucket, 1500); +} + +#[test] +fn shape_bucket_never_drops_below_total_for_valid_ranges() { + for total in [1usize, 32, 127, 512, 999, 1000, 1001, 1499, 1500, 1501] { + let bucket = next_mask_shape_bucket(total, 1000, 1500); + assert!(bucket >= total || total >= 1500, "bucket={bucket} total={total}"); + } +} + +#[tokio::test] +async fn maybe_write_shape_padding_writes_exact_delta() { + let mut writer = CountingWriter::new(); + maybe_write_shape_padding(&mut writer, 1200, true, 
1000, 1500, false, 0).await; + assert_eq!(writer.written, 300); +} + +#[tokio::test] +async fn maybe_write_shape_padding_skips_when_disabled() { + let mut writer = CountingWriter::new(); + maybe_write_shape_padding(&mut writer, 1200, false, 1000, 1500, false, 0).await; + assert_eq!(writer.written, 0); +} + +#[tokio::test] +async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() { + let initial = vec![0x16, 0x03, 0x01, 0x04, 0x00]; + let extra = vec![0xAB; 1195]; + + let (client_reader, mut client_writer) = duplex(4096); + let (mut mask_observer, mask_writer) = duplex(4096); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_reader, + sink(), + empty(), + mask_writer, + &initial, + true, + 1000, + 1500, + false, + 0, + ) + .await; + }); + + client_writer.write_all(&extra).await.unwrap(); + client_writer.shutdown().await.unwrap(); + + relay.await.unwrap(); + + let mut observed = Vec::new(); + mask_observer.read_to_end(&mut observed).await.unwrap(); + assert_eq!(observed.len(), 1500); + assert_eq!(&observed[..5], &[0x16, 0x03, 0x01, 0x04, 0x00]); + assert!(observed[5..1200].iter().all(|b| *b == 0xAB)); + assert_eq!(observed[1200..].len(), 300); +} + +#[test] +fn shape_bucket_light_fuzz_monotonicity_and_bounds() { + let floor = 512usize; + let cap = 4096usize; + let mut prev = 0usize; + + for step in 1usize..=3000 { + let total = ((step * 37) ^ (step << 3)) % (cap + 512); + let bucket = next_mask_shape_bucket(total, floor, cap); + + if total < cap { + assert!(bucket >= total, "bucket={bucket} total={total}"); + assert!(bucket <= cap, "bucket={bucket} cap={cap}"); + } else { + assert_eq!(bucket, total, "above-cap totals must remain unchanged"); + } + + if total >= prev { + // For non-decreasing inputs, bucket class must not regress. 
+ let prev_bucket = next_mask_shape_bucket(prev, floor, cap); + assert!(bucket >= prev_bucket || total >= cap); + } + + prev = total; + } +} diff --git a/src/proxy/tests/masking_timing_normalization_security_tests.rs b/src/proxy/tests/masking_timing_normalization_security_tests.rs new file mode 100644 index 0000000..a5959b4 --- /dev/null +++ b/src/proxy/tests/masking_timing_normalization_security_tests.rs @@ -0,0 +1,120 @@ +use super::*; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum MaskPath { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +async fn measure_bad_client_duration_ms(path: MaskPath, floor_ms: u64, ceiling_ms: u64) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + config.censorship.mask_timing_normalization_enabled = true; + config.censorship.mask_timing_normalization_floor_ms = floor_ms; + config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms; + + let accept_task = match path { + MaskPath::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + None + } + MaskPath::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + MaskPath::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + 
tokio::time::sleep(Duration::from_millis(320)).await; + })) + } + }; + + let (client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let peer: SocketAddr = "198.51.100.221:57121".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /timing-normalize HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = accept_task { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +#[tokio::test] +async fn timing_normalization_envelope_applies_to_connect_fail_and_success() { + let floor = 160u64; + let ceiling = 180u64; + + let fail = measure_bad_client_duration_ms(MaskPath::ConnectFail, floor, ceiling).await; + let success = measure_bad_client_duration_ms(MaskPath::ConnectSuccess, floor, ceiling).await; + + assert!( + fail >= floor as u128, + "connect-fail duration below floor: {fail}ms < {floor}ms" + ); + assert!( + fail <= (ceiling + 60) as u128, + "connect-fail duration exceeded relaxed ceiling: {fail}ms > {}ms", + ceiling + 60 + ); + + assert!( + success >= floor as u128, + "connect-success duration below floor: {success}ms < {floor}ms" + ); + assert!( + success <= (ceiling + 60) as u128, + "connect-success duration exceeded relaxed ceiling: {success}ms > {}ms", + ceiling + 60 + ); + + let delta = fail.abs_diff(success); + assert!( + delta <= 80, + "timing normalization should reduce path divergence (delta={}ms)", + delta + ); +} + +#[tokio::test] +async fn timing_normalization_does_not_sleep_if_path_already_exceeds_ceiling() { + let floor = 120u64; + let ceiling = 150u64; + + let slow = measure_bad_client_duration_ms(MaskPath::SlowBackend, floor, ceiling).await; + + assert!(slow >= 280, "slow backend path should remain slow 
(got {slow}ms)"); + assert!(slow <= 520, "slow backend path should remain bounded in tests (got {slow}ms)"); +} diff --git a/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs b/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs new file mode 100644 index 0000000..3c4a342 --- /dev/null +++ b/src/proxy/tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs @@ -0,0 +1,200 @@ +use super::*; +use tokio::io::duplex; +use tokio::net::TcpListener; +use tokio::time::{Duration, Instant}; + +#[derive(Clone, Copy)] +enum TimingPath { + ConnectFail, + ConnectSuccess, + SlowBackend, +} + +async fn measure_path_duration_ms(path: TimingPath) -> u128 { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = None; + + let maybe_accept = match path { + TimingPath::ConnectFail => { + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = 1; + None + } + TimingPath::ConnectSuccess => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + })) + } + TimingPath::SlowBackend => { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + Some(tokio::spawn(async move { + let (_stream, _) = listener.accept().await.unwrap(); + tokio::time::sleep(Duration::from_millis(350)).await; + })) + } + }; + + let peer: SocketAddr = "198.51.100.213:57013".parse().unwrap(); + let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let beobachten = BeobachtenStore::new(); + + 
let (client_reader, _client_writer) = duplex(1024); + let (_client_visible_reader, client_visible_writer) = duplex(1024); + + let started = Instant::now(); + handle_bad_client( + client_reader, + client_visible_writer, + b"GET /timing HTTP/1.1\r\nHost: x\r\n\r\n", + peer, + local, + &config, + &beobachten, + ) + .await; + + if let Some(task) = maybe_accept { + let _ = tokio::time::timeout(Duration::from_secs(2), task).await; + } + + started.elapsed().as_millis() +} + +fn summarize(values: &[u128]) -> (u128, u128, f64) { + let min = *values.iter().min().unwrap_or(&0); + let max = *values.iter().max().unwrap_or(&0); + let sum: u128 = values.iter().copied().sum(); + let mean = if values.is_empty() { + 0.0 + } else { + sum as f64 / values.len() as f64 + }; + (min, max, mean) +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict path-indistinguishability target"] +async fn redteam_timing_01_connect_fail_success_slow_backend_must_be_within_10ms() { + const ITER: usize = 8; + + let mut fail = Vec::with_capacity(ITER); + let mut success = Vec::with_capacity(ITER); + let mut slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + fail.push(measure_path_duration_ms(TimingPath::ConnectFail).await); + success.push(measure_path_duration_ms(TimingPath::ConnectSuccess).await); + slow.push(measure_path_duration_ms(TimingPath::SlowBackend).await); + } + + let (_, fail_max, fail_mean) = summarize(&fail); + let (_, success_max, success_mean) = summarize(&success); + let (_, slow_max, slow_mean) = summarize(&slow); + + let global_min = *fail + .iter() + .chain(success.iter()) + .chain(slow.iter()) + .min() + .unwrap(); + let global_max = *fail + .iter() + .chain(success.iter()) + .chain(slow.iter()) + .max() + .unwrap(); + + println!( + "redteam_timing path=connect_fail mean_ms={:.2} max_ms={}", + fail_mean, fail_max + ); + println!( + "redteam_timing path=connect_success mean_ms={:.2} max_ms={}", + success_mean, success_max + ); + println!( + "redteam_timing 
path=slow_backend mean_ms={:.2} max_ms={}", + slow_mean, slow_max + ); + + assert!( + global_max.saturating_sub(global_min) <= 10, + "strict model expects all masking outcomes in one 10ms bucket: min={global_min} max={global_max}" + ); +} + +#[tokio::test] +#[ignore = "red-team expected-fail: strict classifier-separability target"] +async fn redteam_timing_02_path_classifier_centroid_accuracy_must_be_below_40pct() { + const ITER: usize = 12; + + let mut fail = Vec::with_capacity(ITER); + let mut success = Vec::with_capacity(ITER); + let mut slow = Vec::with_capacity(ITER); + + for _ in 0..ITER { + fail.push(measure_path_duration_ms(TimingPath::ConnectFail).await as f64); + success.push(measure_path_duration_ms(TimingPath::ConnectSuccess).await as f64); + slow.push(measure_path_duration_ms(TimingPath::SlowBackend).await as f64); + } + + let mean = |v: &Vec| -> f64 { v.iter().sum::() / v.len() as f64 }; + let c_fail = mean(&fail); + let c_success = mean(&success); + let c_slow = mean(&slow); + + let mut correct = 0usize; + let mut total = 0usize; + + let classify = |x: f64, c0: f64, c1: f64, c2: f64| -> usize { + let d0 = (x - c0).abs(); + let d1 = (x - c1).abs(); + let d2 = (x - c2).abs(); + if d0 <= d1 && d0 <= d2 { + 0 + } else if d1 <= d0 && d1 <= d2 { + 1 + } else { + 2 + } + }; + + for &x in &fail { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 0 { + correct += 1; + } + } + for &x in &success { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 1 { + correct += 1; + } + } + for &x in &slow { + total += 1; + if classify(x, c_fail, c_success, c_slow) == 2 { + correct += 1; + } + } + + let accuracy = correct as f64 / total as f64; + println!( + "redteam_timing_classifier accuracy={:.3} c_fail={:.2} c_success={:.2} c_slow={:.2}", + accuracy, c_fail, c_success, c_slow + ); + + assert!( + accuracy <= 0.40, + "strict model expects poor classifier; observed accuracy={accuracy:.3}" + ); +} diff --git 
a/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs new file mode 100644 index 0000000..574a3f9 --- /dev/null +++ b/src/proxy/tests/middle_relay_desync_all_full_dedup_security_tests.rs @@ -0,0 +1,179 @@ +use super::*; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::thread; + +#[test] +fn desync_all_full_bypass_does_not_initialize_or_grow_dedup_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let initial_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + let now = Instant::now(); + + for i in 0..20_000u64 { + assert!( + should_emit_full_desync(0xD35E_D000_0000_0000u64 ^ i, true, now), + "desync_all_full path must always emit" + ); + } + + let after_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0); + assert_eq!( + after_len, initial_len, + "desync_all_full bypass must not allocate or accumulate dedup entries" + ); +} + +#[test] +fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(7); + dedup.insert(0xAAAABBBBCCCCDDDD, seed_time); + dedup.insert(0x1111222233334444, seed_time); + + let now = Instant::now(); + for i in 0..2048u64 { + assert!( + should_emit_full_desync(0xF011_F000_0000_0000u64 ^ i, true, now), + "desync_all_full must bypass suppression and dedup refresh" + ); + } + + assert_eq!(dedup.len(), 2, "bypass path must not mutate dedup cardinality"); + assert_eq!( + *dedup + .get(&0xAAAABBBBCCCCDDDD) + .expect("seed key must remain"), + seed_time, + "bypass path must not refresh existing dedup timestamps" + ); + assert_eq!( + *dedup + 
.get(&0x1111222233334444) + .expect("seed key must remain"), + seed_time, + "bypass path must not touch unrelated dedup entries" + ); +} + +#[test] +fn edge_all_full_burst_does_not_poison_later_false_path_tracking() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for i in 0..8192u64 { + assert!(should_emit_full_desync(0xABCD_0000_0000_0000 ^ i, true, now)); + } + + let tracked_key = 0xDEAD_BEEF_0000_0001u64; + assert!( + should_emit_full_desync(tracked_key, false, now), + "first false-path event after all_full burst must still be tracked and emitted" + ); + + let dedup = DESYNC_DEDUP + .get() + .expect("false path should initialize dedup"); + assert!(dedup.get(&tracked_key).is_some()); +} + +#[test] +fn adversarial_mixed_sequence_true_steps_never_change_cache_len() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + for i in 0..256u64 { + dedup.insert(0x1000_0000_0000_0000 ^ i, Instant::now()); + } + + let mut seed = 0xC0DE_CAFE_BAAD_F00Du64; + for i in 0..4096u64 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let flag_all_full = (seed & 0x1) == 1; + let key = 0x7000_0000_0000_0000u64 ^ i ^ seed; + let before = dedup.len(); + let _ = should_emit_full_desync(key, flag_all_full, Instant::now()); + let after = dedup.len(); + + if flag_all_full { + assert_eq!(after, before, "all_full step must not mutate dedup length"); + } + } +} + +#[test] +fn light_fuzz_all_full_mode_always_emits_and_stays_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let mut seed = 0x1234_5678_9ABC_DEF0u64; + let before = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + + for _ in 0..20_000 { + 
seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + let key = seed ^ 0x55AA_55AA_55AA_55AAu64; + assert!(should_emit_full_desync(key, true, Instant::now())); + } + + let after = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0); + assert_eq!(after, before); + assert!(after <= DESYNC_DEDUP_MAX_ENTRIES); +} + +#[test] +fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let seed_time = Instant::now() - Duration::from_secs(2); + for i in 0..1024u64 { + dedup.insert(0x8888_0000_0000_0000 ^ i, seed_time); + } + let before_len = dedup.len(); + + let emits = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + for worker in 0..16u64 { + let emits = Arc::clone(&emits); + workers.push(thread::spawn(move || { + let now = Instant::now(); + for i in 0..4096u64 { + let key = 0xFACE_0000_0000_0000u64 ^ (worker << 20) ^ i; + if should_emit_full_desync(key, true, now) { + emits.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker must not panic"); + } + + assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096); + assert_eq!(dedup.len(), before_len, "parallel all_full storm must not mutate cache len"); +} diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs new file mode 100644 index 0000000..0efc904 --- /dev/null +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -0,0 +1,799 @@ +use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::atomic::AtomicU64; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; +use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; 
+ +fn make_crypto_reader(reader: T) -> CryptoReader +where + T: AsyncRead + Unpin + Send + 'static, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState { + RelayForensicsState { + trace_id: 0xA000_0000 + conn_id, + conn_id, + user: format!("idle-test-user-{conn_id}"), + peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"), + peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")), + started_at, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdlePolicy { + RelayClientIdlePolicy { + enabled: true, + soft_idle: Duration::from_millis(soft_ms), + hard_idle: Duration::from_millis(hard_ms), + grace_after_downstream_activity: Duration::from_millis(grace_ms), + legacy_frame_read_timeout: Duration::from_millis(hard_ms), + } +} + +fn idle_pressure_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> { + match idle_pressure_test_lock().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + +#[tokio::test] +async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(1, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let 
idle_policy = make_idle_policy(40, 120, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let start = TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("idle test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay hard idle timeout"), + "hard close must expose a clear timeout reason" + ); + assert!( + start.elapsed() >= TokioDuration::from_millis(80), + "hard timeout must not trigger before idle deadline window" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 1); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); +} + +#[tokio::test] +async fn idle_policy_downstream_activity_grace_extends_hard_deadline() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(2, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 60, 100); + let last_downstream_activity_ms = AtomicU64::new(20); + + let start = TokioInstant::now(); + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + 
.await + .expect("grace test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert!( + start.elapsed() >= TokioDuration::from_millis(100), + "recent downstream activity must extend hard idle deadline" + ); +} + +#[tokio::test] +async fn relay_idle_policy_disabled_keeps_legacy_timeout_behavior() { + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(3, Instant::now()); + let mut frame_counter = 0u64; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + Duration::from_millis(60), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + let err_text = match result { + Err(ProxyError::Io(ref e)) => e.to_string(), + _ => String::new(), + }; + assert!( + err_text.contains("middle-relay client frame read timeout"), + "legacy mode must keep expected timeout reason" + ); + assert_eq!(stats.get_relay_idle_soft_mark_total(), 0); + assert_eq!(stats.get_relay_idle_hard_close_total(), 0); +} + +#[tokio::test] +async fn adversarial_partial_frame_trickle_cannot_bypass_hard_idle_close() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(4, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(30, 90, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let mut plaintext = Vec::with_capacity(12); + plaintext.extend_from_slice(&8u32.to_le_bytes()); + plaintext.extend_from_slice(&[1, 2, 
3, 4, 5, 6, 7, 8]); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted[..1]) + .await + .expect("must write a single trickle byte"); + + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("partial frame trickle test must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert_eq!(frame_counter, 0, "partial trickle must not count as a valid frame"); +} + +#[tokio::test] +async fn successful_client_frame_resets_soft_idle_mark() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(5, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + idle_state.soft_idle_marked = true; + let idle_policy = make_idle_policy(200, 300, 0); + let last_downstream_activity_ms = AtomicU64::new(0); + + let payload = [9u8, 8, 7, 6, 5, 4, 3, 2]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer + .write_all(&encrypted) + .await + .expect("must write full encrypted frame"); + + let read = read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ) + .await + .expect("frame read must succeed") + 
.expect("frame must be returned"); + + assert_eq!(read.0.as_ref(), &payload); + assert_eq!(frame_counter, 1); + assert!( + !idle_state.soft_idle_marked, + "a valid client frame must clear soft-idle mark" + ); +} + +#[tokio::test] +async fn protocol_desync_small_frame_updates_reason_counter() { + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics(6, Instant::now()); + let mut frame_counter = 0u64; + + let mut plaintext = Vec::with_capacity(7); + plaintext.extend_from_slice(&3u32.to_le_bytes()); + plaintext.extend_from_slice(&[1u8, 2, 3]); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.expect("must write frame"); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small"))); + assert_eq!(stats.get_relay_protocol_desync_close_total(), 1); +} + +#[tokio::test] +async fn stress_many_idle_sessions_fail_closed_without_hang() { + let mut tasks = Vec::with_capacity(24); + + for idx in 0..24u64 { + tasks.push(tokio::spawn(async move { + let (reader, _writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let session_started_at = Instant::now(); + let forensics = make_forensics(100 + idx, session_started_at); + let mut frame_counter = 0u64; + let mut idle_state = RelayClientIdleState::new(session_started_at); + let idle_policy = make_idle_policy(20, 50, 10); + let last_downstream_activity_ms = AtomicU64::new(0); + + let result = timeout( + TokioDuration::from_secs(2), + read_client_payload_with_idle_policy( + &mut crypto_reader, + ProtoTag::Intermediate, + 
1024, + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + &idle_policy, + &mut idle_state, + &last_downstream_activity_ms, + session_started_at, + ), + ) + .await + .expect("stress task must complete"); + + assert!(matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut)); + assert_eq!(stats.get_relay_idle_hard_close_total(), 1); + assert_eq!(frame_counter, 0); + })); + } + + for task in tasks { + task.await.expect("stress task must not panic"); + } +} + +#[test] +fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(10)); + assert!(mark_relay_idle_candidate(11)); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + note_relay_pressure_event(); + + let mut seen_for_newer = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(11, &mut seen_for_newer, &stats), + "newer idle candidate must not be evicted while older candidate exists" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(10)); + + let mut seen_for_oldest = 0u64; + assert!( + maybe_evict_idle_candidate_on_pressure(10, &mut seen_for_oldest, &stats), + "oldest idle candidate must be evicted first under pressure" + ); + assert_eq!(oldest_relay_idle_candidate(), Some(11)); + assert_eq!(stats.get_relay_pressure_evict_total(), 1); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn pressure_does_not_evict_without_new_pressure_signal() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(21)); + let mut seen = relay_pressure_event_seq(); + + assert!( + !maybe_evict_idle_candidate_on_pressure(21, &mut seen, &stats), + "without new pressure signal, candidate must stay" + ); + assert_eq!(stats.get_relay_pressure_evict_total(), 0); + 
assert_eq!(oldest_relay_idle_candidate(), Some(21)); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn stress_pressure_eviction_preserves_fifo_across_many_candidates() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + let mut seen_per_conn = std::collections::HashMap::new(); + for conn_id in 1000u64..1064u64 { + assert!(mark_relay_idle_candidate(conn_id)); + seen_per_conn.insert(conn_id, 0u64); + } + + for expected in 1000u64..1064u64 { + note_relay_pressure_event(); + + let mut seen = *seen_per_conn + .get(&expected) + .expect("per-conn pressure cursor must exist"); + assert!( + maybe_evict_idle_candidate_on_pressure(expected, &mut seen, &stats), + "expected conn_id {expected} must be evicted next by deterministic FIFO ordering" + ); + seen_per_conn.insert(expected, seen); + + let next = if expected == 1063 { + None + } else { + Some(expected + 1) + }; + assert_eq!(oldest_relay_idle_candidate(), next); + } + + assert_eq!(stats.get_relay_pressure_evict_total(), 64); + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(301)); + assert!(mark_relay_idle_candidate(302)); + assert!(mark_relay_idle_candidate(303)); + + let mut seen_301 = 0u64; + let mut seen_302 = 0u64; + let mut seen_303 = 0u64; + + // Single pressure event should authorize at most one eviction globally. 
+ note_relay_pressure_event(); + + let evicted_301 = maybe_evict_idle_candidate_on_pressure(301, &mut seen_301, &stats); + let evicted_302 = maybe_evict_idle_candidate_on_pressure(302, &mut seen_302, &stats); + let evicted_303 = maybe_evict_idle_candidate_on_pressure(303, &mut seen_303, &stats); + + let evicted_total = [evicted_301, evicted_302, evicted_303] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted_total, 1, + "single pressure event must not cascade-evict multiple idle candidates" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + assert!(mark_relay_idle_candidate(401)); + assert!(mark_relay_idle_candidate(402)); + + let mut seen_oldest = 0u64; + let mut seen_next = 0u64; + + note_relay_pressure_event(); + + assert!( + maybe_evict_idle_candidate_on_pressure(401, &mut seen_oldest, &stats), + "oldest candidate must consume pressure budget first" + ); + + assert!( + !maybe_evict_idle_candidate_on_pressure(402, &mut seen_next, &stats), + "next candidate must not consume the same pressure budget" + ); + + assert_eq!( + stats.get_relay_pressure_evict_total(), + 1, + "single pressure budget must produce exactly one eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + // Pressure happened before any idle candidate existed. 
+ note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(501)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(501, &mut seen, &stats), + "stale pressure (before soft-idle mark) must not evict newly marked candidate" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(511)); + assert!(mark_relay_idle_candidate(512)); + assert!(mark_relay_idle_candidate(513)); + + let mut seen_511 = 0u64; + let mut seen_512 = 0u64; + let mut seen_513 = 0u64; + + let evicted = [ + maybe_evict_idle_candidate_on_pressure(511, &mut seen_511, &stats), + maybe_evict_idle_candidate_on_pressure(512, &mut seen_512, &stats), + maybe_evict_idle_candidate_on_pressure(513, &mut seen_513, &stats), + ] + .iter() + .filter(|value| **value) + .count(); + + assert_eq!( + evicted, 0, + "stale pressure event must not evict any candidate from a newly marked batch" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + + // Session A observed pressure while there were no candidates. + let mut seen_a = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(999_001, &mut seen_a, &stats), + "no candidate existed, so no eviction is possible" + ); + + // Candidate appears later; Session B must not be able to consume stale pressure. 
+ assert!(mark_relay_idle_candidate(521)); + let mut seen_b = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(521, &mut seen_b, &stats), + "once pressure is observed with empty candidate set, it must not be replayed later" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_stale_pressure_must_not_survive_candidate_churn() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + let stats = Stats::new(); + + note_relay_pressure_event(); + assert!(mark_relay_idle_candidate(531)); + clear_relay_idle_candidate(531); + assert!(mark_relay_idle_candidate(532)); + + let mut seen = 0u64; + assert!( + !maybe_evict_idle_candidate_on_pressure(532, &mut seen, &stats), + "stale pressure must not survive clear+remark churn cycles" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX - 1; + } + + // A new pressure event should still be representable; saturating at MAX creates a permanent lockout. 
+ note_relay_pressure_event(); + let after = relay_pressure_event_seq(); + assert_ne!( + after, + u64::MAX, + "pressure sequence saturation must not permanently freeze event progression" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[test] +fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + { + let mut guard = relay_idle_candidate_registry() + .lock() + .expect("registry lock must be available"); + guard.pressure_event_seq = u64::MAX; + guard.pressure_consumed_seq = u64::MAX; + } + + note_relay_pressure_event(); + let first = relay_pressure_event_seq(); + note_relay_pressure_event(); + let second = relay_pressure_event_seq(); + + assert!( + second > first, + "distinct pressure events must remain distinguishable even at sequence boundary" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 16usize; + let rounds = 200usize; + let conn_ids: Vec = (10_000u64..10_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + for round in 0..rounds { + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + (idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + let 
mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + assert!( + evicted_this_round <= 1, + "round {round}: one pressure event must never produce more than one eviction" + ); + if let Some(conn) = evicted_conn { + assert!( + mark_relay_idle_candidate(conn), + "round {round}: evicted conn must be re-markable as idle candidate" + ); + } + } + + assert!( + stats.get_relay_pressure_evict_total() <= rounds as u64, + "eviction total must never exceed number of pressure events" + ); + assert!( + stats.get_relay_pressure_evict_total() > 0, + "parallel race must still observe at least one successful eviction" + ); + + clear_relay_idle_pressure_state_for_testing(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() { + let _guard = acquire_idle_pressure_test_lock(); + clear_relay_idle_pressure_state_for_testing(); + + let stats = Arc::new(Stats::new()); + let sessions = 12usize; + let rounds = 120usize; + let conn_ids: Vec = (20_000u64..20_000u64 + sessions as u64).collect(); + let mut seen_per_session = vec![0u64; sessions]; + + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + + let mut expected_total_evictions = 0u64; + + for round in 0..rounds { + let empty_phase = round % 5 == 0; + if empty_phase { + for conn_id in &conn_ids { + clear_relay_idle_candidate(*conn_id); + } + } + + note_relay_pressure_event(); + + let mut joins = Vec::with_capacity(sessions); + for (idx, conn_id) in conn_ids.iter().enumerate() { + let mut seen = seen_per_session[idx]; + let conn_id = *conn_id; + let stats = stats.clone(); + joins.push(tokio::spawn(async move { + let evicted = maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref()); + 
(idx, conn_id, seen, evicted) + })); + } + + let mut evicted_this_round = 0usize; + let mut evicted_conn = None; + for join in joins { + let (idx, conn_id, seen, evicted) = join.await.expect("burst race task must not panic"); + seen_per_session[idx] = seen; + if evicted { + evicted_this_round += 1; + evicted_conn = Some(conn_id); + } + } + + if empty_phase { + assert_eq!( + evicted_this_round, 0, + "round {round}: empty candidate phase must not allow stale-pressure eviction" + ); + for conn_id in &conn_ids { + assert!(mark_relay_idle_candidate(*conn_id)); + } + } else { + assert!( + evicted_this_round <= 1, + "round {round}: pressure budget must cap at one eviction" + ); + if let Some(conn_id) = evicted_conn { + expected_total_evictions = expected_total_evictions.saturating_add(1); + assert!(mark_relay_idle_candidate(conn_id)); + } + } + } + + assert_eq!( + stats.get_relay_pressure_evict_total(), + expected_total_evictions, + "global pressure eviction counter must match observed per-round successful consumes" + ); + + clear_relay_idle_pressure_state_for_testing(); +} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs new file mode 100644 index 0000000..874e5ea --- /dev/null +++ b/src/proxy/tests/middle_relay_security_tests.rs @@ -0,0 +1,2517 @@ +use super::*; +use crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; +use bytes::Bytes; +use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; +use crate::transport::middle_proxy::MePool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::collections::{HashMap, HashSet}; +use std::net::SocketAddr; +use std::sync::Arc; +use 
std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::thread; +use tokio::sync::Barrier; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; +use tokio::time::{Duration as TokioDuration, timeout}; +use std::sync::{Mutex, OnceLock}; + +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +fn quota_user_lock_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + +#[test] +fn should_yield_sender_only_on_budget_with_backlog() { + assert!(!should_yield_c2me_sender(0, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); + assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); + assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); +} + +#[tokio::test] +async fn enqueue_c2me_command_uses_try_send_fast_path() { + let (tx, mut rx) = mpsc::channel::(2); + enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: make_pooled_payload(&[1, 2, 3]), + flags: 0, + }, + ) + .await + .unwrap(); + + let recv = timeout(TokioDuration::from_millis(50), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[1, 2, 3]); + assert_eq!(flags, 0); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[tokio::test] +async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[9]), + flags: 9, + }) + 
.await + .unwrap(); + + let tx2 = tx.clone(); + let producer = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload(&[7, 7]), + flags: 7, + }, + ) + .await + .unwrap(); + }); + + let _ = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap(); + producer.await.unwrap(); + + let recv = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .unwrap(); + match recv { + C2MeCommand::Data { payload, flags } => { + assert_eq!(payload.as_ref(), &[7, 7]); + assert_eq!(flags, 7); + } + C2MeCommand::Close => panic!("unexpected close command"), + } +} + +#[tokio::test] +async fn enqueue_c2me_command_closed_channel_recycles_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); + let (tx, rx) = mpsc::channel::(1); + drop(rx); + + let result = enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload, + flags: 0, + }, + ) + .await; + + assert!(result.is_err(), "closed queue must fail enqueue"); + drop(result); + assert!( + pool.stats().pooled >= 1, + "payload must return to pool when enqueue fails on closed channel" + ); +} + +#[tokio::test] +async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let (tx, rx) = mpsc::channel::(1); + + tx.send(C2MeCommand::Data { + payload: make_pooled_payload_from(&pool, &[9]), + flags: 1, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let pool2 = pool.clone(); + let blocked_send = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), + flags: 2, + }, + ) + .await + }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), blocked_send) + .await + .expect("blocked send task must finish") + .expect("blocked send task must not panic"); + 
+ assert!( + result.is_err(), + "closing receiver while sender is blocked must fail enqueue" + ); + drop(result); + assert!( + pool.stats().pooled >= 2, + "both queued and blocked payloads must return to pool after channel close" + ); +} + +#[tokio::test] +async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { + let (tx, _rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[1]), + flags: 0, + }) + .await + .unwrap(); + + let started = Instant::now(); + let result = enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload: make_pooled_payload(&[2, 2]), + flags: 1, + }, + ) + .await; + + assert!( + result.is_err(), + "enqueue must fail when queue stays full beyond bounded timeout" + ); + assert!( + started.elapsed() < TokioDuration::from_millis(400), + "full-queue timeout must resolve promptly" + ); +} + +#[test] +fn desync_dedup_cache_is_bounded() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!( + should_emit_full_desync(key, false, now), + "unique keys up to cap must be tracked" + ); + } + + assert!( + should_emit_full_desync(u64::MAX, false, now), + "new key above cap must emit once after bounded eviction for forensic visibility" + ); + + assert!( + !should_emit_full_desync(u64::MAX, false, now), + "already tracked key inside dedup window must stay suppressed" + ); +} + +#[test] +fn quota_user_lock_cache_reuses_entry_for_same_user() { + let a = quota_user_lock("quota-user-a"); + let b = quota_user_lock("quota-user-a"); + assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); +} + +#[test] +fn quota_user_lock_cache_is_bounded_under_unique_churn() { + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); + + let map = 
QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { + let user = format!("quota-user-{idx}"); + let lock = quota_user_lock(&user); + drop(lock); + } + + assert!( + map.len() <= QUOTA_USER_LOCKS_MAX, + "quota lock cache must stay within configured bound" + ); +} + +#[test] +fn quota_user_lock_cache_saturation_returns_ephemeral_lock_without_growth() { + let _guard = quota_user_lock_test_lock() + .lock() + .expect("quota user lock test lock must be available"); + + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + for attempt in 0..8u32 { + map.clear(); + + let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("{prefix}-{idx}"); + retained.push(quota_user_lock(&user)); + } + + if map.len() != QUOTA_USER_LOCKS_MAX { + drop(retained); + continue; + } + + let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow acquisition must not grow cache past hard limit" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow path should not cache new user lock when map is saturated and all entries are retained" + ); + assert!( + !Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user lock should be ephemeral under saturation to preserve bounded cache size" + ); + + drop(retained); + return; + } + + panic!( + "unable to observe stable saturated lock-cache precondition after bounded retries" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = 
Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("quota-saturated-user-{idx}"); + retained.push(quota_user_lock(&user)); + } + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "precondition: cache must be saturated for overflow-user race test" + ); + + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "gap-t04-saturated-lock-race-user"; + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); + let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); + let (r1, r2) = tokio::join!(one, two); + + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "both racers must resolve cleanly without unexpected errors" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "at least one racer must be quota-rejected even when lock cache is saturated" + ); + assert_eq!( + stats.get_user_total_octets(user), + 1, + "saturated lock cache must not permit double-success quota overshoot" + ); + + drop(retained); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + let user = format!("quota-saturated-stress-holder-{idx}"); + retained.push(quota_user_lock(&user)); + } + + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + for round in 0..128u64 { + let user = format!("gap-t04-saturated-race-round-{round}"); + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x71, + 12_000 + round, + barrier.clone(), + ); + let two = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x72, + 13_000 + round, + barrier, + ); + + let (r1, r2) = tokio::join!(one, two); + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "round {round}: racers must resolve cleanly" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "round {round}: at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "round {round}: saturated cache must still enforce exactly one forwarded byte" + ); + } + + drop(retained); +} + +#[test] +fn adversarial_forensics_trace_id_should_not_alias_conn_id() { + let now = Instant::now(); + let trace_id = 0x1122_3344_5566_7788; + let conn_id = 0x8877_6655_4433_2211; + let state = RelayForensicsState { + trace_id, + conn_id, + user: "trace-user".to_string(), + peer: "198.51.100.17:443".parse().unwrap(), + peer_hash: 0x8877_6655_4433_2211, + started_at: now, + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + }; + + assert_ne!( + state.trace_id, state.conn_id, + "security expectation: trace correlation should be independent of connection identity" + ); + assert_eq!(state.trace_id, trace_id); + assert_eq!(state.conn_id, conn_id); +} + +#[tokio::test] +async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { + let (mut writer_side, reader_side) = duplex(8); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); + + write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) + .await + .expect("ack write must succeed"); + + let mut observed = [0u8; 4]; + writer_side + .read_exact(&mut observed) + .await + .expect("ack bytes must be readable"); + let mut decryptor = AesCtr::new(&key, iv); + let decrypted = decryptor.decrypt(&observed); + + assert_eq!( + decrypted, + 0x11_22_33_44u32.to_be_bytes(), + "abridged ACK should encode confirm bytes in big-endian order" + ); +} + +#[test] +fn desync_dedup_full_cache_churn_stays_suppressed() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!(should_emit_full_desync(key, 
false, now)); + } + + for offset in 0..2048u64 { + let emitted = should_emit_full_desync(u64::MAX - offset, false, now); + if offset == 0 { + assert!( + emitted, + "first full-cache newcomer should emit for forensic visibility" + ); + } else { + assert!( + !emitted, + "full-cache newcomer churn inside emit interval must stay suppressed" + ); + } + } +} + +#[test] +fn dedup_hash_is_stable_for_same_input_within_process() { + let sample = ( + "scope_user", + hash_ip("198.51.100.7".parse().unwrap()), + ProtoTag::Secure, + ); + let first = hash_value(&sample); + let second = hash_value(&sample); + assert_eq!( + first, second, + "dedup hash must be stable within a process for cache lookups" + ); +} + +#[test] +fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { + let mut seen = HashSet::new(); + + for octet in 1u16..=2048 { + let third = ((octet / 256) & 0xff) as u8; + let fourth = (octet & 0xff) as u8; + let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); + let key = hash_value(&( + "scope_user", + hash_ip(ip), + ProtoTag::Secure, + DESYNC_ERROR_CLASS, + )); + seen.insert(key); + } + + assert_eq!( + seen.len(), + 2048, + "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" + ); +} + +#[test] +fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); + let mut seen = HashSet::new(); + let samples = 8192usize; + + for _ in 0..samples { + let user_seed: u64 = rng.random(); + let peer_seed: u64 = rng.random(); + let proto = if (peer_seed & 1) == 0 { + ProtoTag::Secure + } else { + ProtoTag::Intermediate + }; + let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); + seen.insert(key); + } + + let collisions = samples - seen.len(); + assert!( + collisions <= 1, + "light fuzz collision count should remain negligible for 64-bit dedup keys" + ); +} + +#[test] +fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { + let _guard = 
desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; + + let mut emitted_count = 0usize; + for key in 0..total as u64 { + let emitted = should_emit_full_desync(key, false, now); + if emitted { + emitted_count += 1; + } + } + + assert_eq!( + emitted_count, + DESYNC_DEDUP_MAX_ENTRIES + 1, + "after capacity is reached, same-tick newcomer churn must be rate-limited" + ); + + let len = DESYNC_DEDUP + .get() + .expect("dedup cache must be initialized by stress run") + .len(); + assert!( + len <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must stay bounded under stress churn" + ); +} + +#[test] +fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // Same-tick newcomer storm: only the first should emit full forensic record. + let mut burst_emits = 0usize; + for i in 0..1024u64 { + if should_emit_full_desync(10_000_000 + i, false, base_now) { + burst_emits += 1; + } + } + assert_eq!( + burst_emits, 1, + "full-cache newcomer burst must be bounded to a single full emit per interval" + ); + + // After each interval elapses, one newcomer may emit again. 
+ for step in 1..=6u64 { + let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; + assert!( + should_emit_full_desync(20_000_000 + step, false, t), + "full-cache newcomer should re-emit once interval has elapsed" + ); + assert!( + !should_emit_full_desync(30_000_000 + step, false, t), + "additional newcomers in the same interval tick must remain suppressed" + ); + } +} + +#[test] +fn full_cache_mode_override_emits_every_event() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for i in 0..10_000u64 { + assert!( + should_emit_full_desync(100_000_000 + i, true, now), + "desync_all_full override must bypass dedup and rate-limit suppression" + ); + } +} + +#[test] +fn report_desync_stats_follow_rate_limited_full_cache_policy() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let stats = Stats::new(); + let mut state = make_forensics_state(); + state.started_at = base_now; + + for i in 0..128u64 { + state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; + let _ = report_desync_frame_too_large( + &state, + ProtoTag::Secure, + 3, + 1024, + 4096, + Some([0x16, 0x03, 0x03, 0x00]), + &stats, + ); + } + + assert_eq!( + stats.get_desync_total(), + 128, + "every detected desync must increment total counter" + ); + assert_eq!( + stats.get_desync_full_logged(), + 1, + "same-interval full-cache newcomer storm must allow only one full forensic emit" + ); + assert_eq!( + stats.get_desync_suppressed(), + 127, + "remaining same-interval full-cache newcomer events must be suppressed" + ); + + // After one full interval in real wall clock, a newcomer 
should emit again. + thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); + state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; + let _ = report_desync_frame_too_large( + &state, + ProtoTag::Secure, + 4, + 1024, + 4097, + Some([0x16, 0x03, 0x03, 0x01]), + &stats, + ); + + assert_eq!( + stats.get_desync_full_logged(), + 2, + "full forensic emission must recover after rate-limit interval" + ); +} + +#[test] +fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let emits = Arc::new(AtomicUsize::new(0)); + let mut workers = Vec::new(); + for worker_id in 0..32u64 { + let emits = Arc::clone(&emits); + workers.push(thread::spawn(move || { + for i in 0..512u64 { + let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; + if should_emit_full_desync(key, false, base_now) { + emits.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must not panic"); + } + + assert_eq!( + emits.load(Ordering::Relaxed), + 1, + "concurrent same-interval full-cache storm must allow only one full forensic emit" + ); +} + +#[test] +fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); + let mut model_last_emit: Option = None; + + 
for i in 0..4096u64 { + let jitter_ms: u64 = rng.random_range(0..=3000); + let t = base_now + TokioDuration::from_millis(jitter_ms); + let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::(); + let actual = should_emit_full_desync(key, false, t); + + let expected = match model_last_emit { + None => { + model_last_emit = Some(t); + true + } + Some(last) => { + match t.checked_duration_since(last) { + Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { + model_last_emit = Some(t); + true + } + Some(_) => false, + None => { + // Match production fail-open behavior for non-monotonic synthetic input. + model_last_emit = Some(t); + true + } + } + } + }; + + assert_eq!( + actual, expected, + "full-cache rate-limit gate diverged from reference model under light fuzz" + ); + } +} + +#[test] +fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // Poison the full-cache gate lock intentionally. 
+ let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); + let _ = std::panic::catch_unwind(|| { + let _lock = gate.lock().expect("gate lock must be lockable before poison"); + panic!("intentional gate poison for fail-closed regression"); + }); + + let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); + assert!( + !emitted, + "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" + ); + assert!( + dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must remain bounded even when gate lock is poisoned" + ); +} + +#[test] +fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + // First event seeds the gate. + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0001, + false, + base_now + TokioDuration::from_millis(900) + )); + + // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. + assert!(should_emit_full_desync( + 0xABCD_0000_0000_0002, + false, + base_now + TokioDuration::from_millis(100) + )); + + // Same instant again remains suppressed after reset. + assert!(!should_emit_full_desync( + 0xABCD_0000_0000_0003, + false, + base_now + TokioDuration::from_millis(100) + )); +} + +#[test] +fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); + let base_now = Instant::now(); + + // Fill with fresh entries so stale-pruning does not apply. 
+ for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + dedup.insert(key, base_now - TokioDuration::from_millis(10)); + } + + let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + + let newcomer_key = u64::MAX; + let emitted = should_emit_full_desync(newcomer_key, false, base_now); + assert!( + emitted, + "new entry under full fresh cache must emit after bounded eviction" + ); + assert!( + dedup.get(&newcomer_key).is_some(), + "new key must be inserted after bounded eviction" + ); + + let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); + let removed_count = before_keys.difference(&after_keys).count(); + let added_count = after_keys.difference(&before_keys).count(); + + assert_eq!( + removed_count, 1, + "full-cache insertion must evict exactly one prior key" + ); + assert_eq!( + added_count, 1, + "full-cache insertion must add exactly one newcomer key" + ); + assert!( + dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, + "dedup cache must remain hard-bounded after full-cache churn" + ); +} + +#[test] +fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let key = 0xC0DE_CAFE_u64; + let start = Instant::now(); + + assert!( + should_emit_full_desync(key, false, start), + "first event for key must emit full forensic record" + ); + + // Deterministic pseudo-random time deltas around dedup window edge. 
+ let mut s: u64 = 0x1234_5678_9ABC_DEF0; + for _ in 0..2048 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + + let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); + let now = start + TokioDuration::from_millis(delta_ms); + let emitted = should_emit_full_desync(key, false, now); + + if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { + assert!( + !emitted, + "events inside dedup window must remain suppressed" + ); + } else { + // Once window elapsed for this key, at least one sample should re-emit and refresh. + if emitted { + return; + } + } + } + + panic!("expected at least one post-window sample to re-emit forensic record"); +} + +#[test] +#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"] +fn should_emit_full_desync_filters_duplicates() { + unimplemented!("Stub for M-04"); +} + +#[test] +#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"] +fn desync_dedup_eviction_under_map_full_condition() { + unimplemented!("Stub for M-04"); +} + +#[tokio::test] +#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"] +async fn c2me_channel_full_path_yields_then_sends() { + unimplemented!("Stub for M-05"); +} + +fn make_forensics_state() -> RelayForensicsState { + RelayForensicsState { + trace_id: 1, + conn_id: 2, + user: "test-user".to_string(), + peer: "127.0.0.1:50000".parse::().unwrap(), + peer_hash: 3, + started_at: Instant::now(), + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_crypto_reader(reader: R) -> CryptoReader +where + R: tokio::io::AsyncRead + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn make_crypto_writer(writer: W) -> CryptoWriter +where + W: tokio::io::AsyncWrite + Unpin, +{ + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, 
AesCtr::new(&key, iv), 8 * 1024) +} + +async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + stats, + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_instadrain, + general.me_pool_drain_threshold, + general.me_pool_drain_soft_evict_enabled, + 
general.me_pool_drain_soft_evict_grace_secs, + general.me_pool_drain_soft_evict_per_writer, + general.me_pool_drain_soft_evict_budget_per_core, + general.me_pool_drain_soft_evict_cooldown_ms, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +#[tokio::test] +async fn read_client_payload_times_out_on_header_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, 
+ &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled header read must time out" + ); +} + +#[tokio::test] +async fn read_client_payload_times_out_on_payload_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, mut writer) = duplex(1024); + let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); + writer.write_all(&encrypted_len).await.unwrap(); + + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled payload body read must time out" + ); +} + +#[tokio::test] +async fn read_client_payload_large_intermediate_frame_is_exact() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(262_144); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); + let mut plaintext = Vec::with_capacity(4 + payload_len); + plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + 
ProtoTag::Intermediate, + payload_len + 16, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.len(), payload_len, "payload size must match wire length"); + for (idx, byte) in frame.iter().enumerate() { + assert_eq!(*byte, (idx as u8).wrapping_mul(31)); + } + assert_eq!(frame_counter, 1, "exactly one frame must be counted"); +} + +#[tokio::test] +async fn read_client_payload_secure_strips_tail_padding_bytes() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; + let tail = [0xeeu8, 0xff, 0x99]; + let wire_len = payload.len() + tail.len(); + + let mut plaintext = Vec::with_capacity(4 + wire_len); + plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + plaintext.extend_from_slice(&tail); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("secure payload read must succeed") + .expect("secure frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.as_ref(), &payload); + assert_eq!(frame_counter, 1, "one secure frame must be counted"); +} + +#[tokio::test] +async fn 
read_client_payload_secure_rejects_wire_len_below_4() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let mut plaintext = Vec::with_capacity(7); + plaintext.extend_from_slice(&3u32.to_le_bytes()); + plaintext.extend_from_slice(&[1u8, 2, 3]); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), + "secure wire length below 4 must be fail-closed by the frame-too-small guard" + ); +} + +#[tokio::test] +async fn read_client_payload_intermediate_skips_zero_len_frame() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; + let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + 
&mut frame_counter, + &stats, + ) + .await + .expect("intermediate payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.as_ref(), &payload); + assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); +} + +#[tokio::test] +async fn read_client_payload_abridged_extended_len_sets_quickack() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload_len = 4 * 130; + let len_words = (payload_len / 4) as u32; + let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); + plaintext.push(0xff | 0x80); + let lw = len_words.to_le_bytes(); + plaintext.extend_from_slice(&lw[..3]); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Abridged, + payload_len + 16, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("abridged payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(quickack, "quickack bit must be propagated from abridged header"); + assert_eq!(frame.len(), payload_len); + assert_eq!(frame_counter, 1, "one abridged frame must be counted"); +} + +#[tokio::test] +async fn read_client_payload_returns_buffer_to_pool_after_emit() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 8)); + pool.preallocate(1); + 
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); + + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + // Force growth beyond default pool buffer size to catch ownership-take regressions. + let payload_len = 257usize; + let mut plaintext = Vec::with_capacity(4 + payload_len); + plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let _ = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + payload_len + 8, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert_eq!(frame_counter, 1); + let pool_stats = pool.stats(); + assert!( + pool_stats.pooled >= 1, + "emitted payload buffer must be returned to pool to avoid pool drain" + ); +} + +#[tokio::test] +async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 2)); + pool.preallocate(1); + assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = 
encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let (frame, quickack) = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert!(!quickack); + assert_eq!(frame.as_ref(), &payload); + assert_eq!( + pool.stats().pooled, + 0, + "buffer must stay checked out while frame payload is alive" + ); + + drop(frame); + assert!( + pool.stats().pooled >= 1, + "buffer must return to pool only after frame drop" + ); +} + +#[tokio::test] +async fn enqueue_c2me_close_unblocks_after_queue_drain() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x41]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("first queued item must be present"); + assert!(matches!(first, C2MeCommand::Data { .. 
})); + + close_task.await.unwrap().expect("close enqueue must succeed after drain"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("close command must follow after queue drain"); + assert!(matches!(second, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { + let (tx, rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x42]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), close_task) + .await + .expect("close task must finish") + .expect("close task must not panic"); + assert!( + result.is_err(), + "close enqueue must fail cleanly when receiver is dropped under pressure" + ); +} + +#[tokio::test] +async fn process_me_writer_response_ack_obeys_flush_policy() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let immediate = process_me_writer_response( + MeResponse::Ack(0x11223344), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + &bytes_me2c, + 77, + true, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + immediate, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: true, + } + )); + + let delayed = process_me_writer_response( + MeResponse::Ack(0x55667788), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + &bytes_me2c, + 77, + false, + false, + ) + .await + .expect("ack response must be processed"); + + 
assert!(matches!( + delayed, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: false, + } + )); +} + +#[tokio::test] +async fn process_me_writer_response_data_updates_byte_accounting() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let outcome = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload.clone()), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + None, + &bytes_me2c, + 88, + false, + false, + ) + .await + .expect("data response must be processed"); + + assert!(matches!( + outcome, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes, + flush_immediately: false, + } if bytes == payload.len() + )); + assert_eq!( + bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), + payload.len() as u64, + "ME->C byte accounting must increase by emitted payload size" + ); +} + +#[tokio::test] +async fn process_me_writer_response_data_enforces_live_user_quota() { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_from("quota-user", 10); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "quota-user", + Some(12), + &bytes_me2c, + 89, + false, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), + "ME->client runtime path must terminate when live user quota is crossed" + ); + + let mut 
raw = [0u8; 1]; + assert!( + timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) + .await + .is_err(), + "quota exhaustion must not write any ciphertext to the client stream" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "quota-race-user"; + + let (writer_side_a, _reader_side_a) = duplex(1024); + let (writer_side_b, _reader_side_b) = duplex(1024); + let mut writer_a = make_crypto_writer(writer_side_a); + let mut writer_b = make_crypto_writer(writer_side_b); + let mut frame_buf_a = Vec::new(); + let mut frame_buf_b = Vec::new(); + let rng_a = SecureRandom::new(); + let rng_b = SecureRandom::new(); + + let fut_a = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x11]), + }, + &mut writer_a, + ProtoTag::Intermediate, + &rng_a, + &mut frame_buf_a, + &stats, + user, + Some(1), + &bytes_me2c, + 91, + false, + false, + ); + let fut_b = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x22]), + }, + &mut writer_b, + ProtoTag::Intermediate, + &rng_b, + &mut frame_buf_b, + &stats, + user, + Some(1), + &bytes_me2c, + 92, + false, + false, + ); + + let (result_a, result_b) = tokio::join!(fut_a, fut_b); + + assert!( + matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_a, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") + || matches!(result_b, Ok(_)), + "concurrent quota test must complete without panicking" + ); + assert!( + stats.get_user_total_octets(user) <= 1, + "same-user concurrent middle-relay responses must not overshoot the configured quota" + ); +} + +#[tokio::test] +async fn 
process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() { + let (writer_side, mut reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + stats.add_user_octets_to("partial-quota-user", 3); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![1u8, 2, 3, 4]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "partial-quota-user", + Some(4), + &bytes_me2c, + 90, + false, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), + "ME->client runtime path must reject oversized payloads before writing" + ); + + let mut raw = [0u8; 1]; + assert!( + timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) + .await + .is_err(), + "oversized payloads must not leak any partial ciphertext to the client stream" + ); +} + +#[tokio::test] +async fn middle_relay_abort_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "abort-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + 
enc_iv: 0, + peer: "127.0.0.1:50001".parse().unwrap(), + is_tls: false, + }; + + let relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xdecafbad, + )); + + let started = tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await; + assert!(started.is_ok(), "middle relay must increment route gauge before abort"); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted middle relay task must return join error"); + + tokio::time::sleep(TokioDuration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay task is aborted mid-flight" + ); + + drop(client_side); +} + +#[tokio::test] +async fn middle_relay_cutover_midflight_releases_route_gauge() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let (server_side, client_side) = duplex(64 * 1024); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: "cutover-middle-user".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: "127.0.0.1:50003".parse().unwrap(), + is_tls: false, + }; + + let 
relay_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool, + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_snapshot, + 0xfeed_beef, + )); + + tokio::time::timeout(TokioDuration::from_secs(2), async { + loop { + if stats.get_current_connections_me() == 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("middle relay must increment route gauge before cutover"); + + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance route generation" + ); + + let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) + .await + .expect("middle relay must terminate after cutover") + .expect("middle relay task must not panic"); + assert!( + relay_result.is_err(), + "cutover should terminate middle relay session" + ); + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "client-visible cutover error must stay generic and avoid route-internal metadata" + ); + + assert_eq!( + stats.get_current_connections_me(), + 0, + "route gauge must be released when middle relay exits on cutover" + ); + + drop(client_side); +} + +async fn run_quota_race_attempt( + stats: &Stats, + bytes_me2c: &AtomicU64, + user: &str, + payload: u8, + conn_id: u64, + barrier: Arc, +) -> Result { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + + barrier.wait().await; + process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(vec![payload]), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + stats, + user, + Some(1), + bytes_me2c, + conn_id, + false, + false, + ) + .await +} + +#[tokio::test] +async fn 
abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(256); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Abridged, + 4096, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!(result.is_err(), "oversized abridged length must fail closed"); + assert_eq!(frame_counter, 0, "oversized frame must not be counted as accepted"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + let user = "gap-t04-race-user"; + let barrier = Arc::new(Barrier::new(2)); + + let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); + let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); + + let (r1, r2) = tokio::join!(f1, f2); + + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "first racer must either finish or fail closed on quota" + ); + assert!( + matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "second racer must either finish or fail closed on quota" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(user), + 1, + "same-user race must forward/account exactly one payload byte" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_quota_race_bursts_never_allow_double_success_per_round() { + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + for round in 0..128u64 { + let user = format!("gap-t04-race-burst-{round}"); + let barrier = Arc::new(Barrier::new(2)); + + let one = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x33, + 6000 + round, + barrier.clone(), + ); + let two = run_quota_race_attempt( + &stats, + &bytes_me2c, + &user, + 0x44, + 7000 + round, + barrier, + ); + + let (r1, r2) = tokio::join!(one, two); + assert!( + matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) + && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), + "round {round}: racers must resolve cleanly without unexpected errors" + ); + assert!( + matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) + || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. 
})), + "round {round}: at least one racer must be quota-rejected" + ); + assert_eq!( + stats.get_user_total_octets(&user), + 1, + "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" + ); + } +} + +#[tokio::test] +async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { + let session_count = 6usize; + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + let route_snapshot = route_runtime.snapshot(); + + let mut relay_tasks = Vec::with_capacity(session_count); + let mut client_sides = Vec::with_capacity(session_count); + + for idx in 0..session_count { + let (server_side, client_side) = duplex(64 * 1024); + client_sides.push(client_side); + let (server_reader, server_writer) = tokio::io::split(server_side); + let crypto_reader = make_crypto_reader(server_reader); + let crypto_writer = make_crypto_writer(server_writer); + + let success = HandshakeSuccess { + user: format!("cutover-storm-middle-user-{idx}"), + dc_idx: 2, + proto_tag: ProtoTag::Intermediate, + dec_key: [0u8; 32], + dec_iv: 0, + enc_key: [0u8; 32], + enc_iv: 0, + peer: SocketAddr::new( + std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), + 52000 + idx as u16, + ), + is_tls: false, + }; + + relay_tasks.push(tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_snapshot, + 0xB000_0000 + idx as u64, + ))); + } + + tokio::time::timeout(TokioDuration::from_secs(4), async { + loop { + if stats.get_current_connections_me() == session_count as u64 { + break; + } + 
tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("all middle sessions must become active before cutover storm"); + + let route_runtime_flipper = route_runtime.clone(); + let flipper = tokio::spawn(async move { + for step in 0..64u32 { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = route_runtime_flipper.set_mode(mode); + tokio::time::sleep(TokioDuration::from_millis(15)).await; + } + }); + + for relay_task in relay_tasks { + let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) + .await + .expect("middle relay task must finish under cutover storm") + .expect("middle relay task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "storm-cutover termination must remain generic for all middle sessions" + ); + } + + flipper.abort(); + let _ = flipper.await; + + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after cutover storm" + ); + + drop(client_sides); +} + +#[tokio::test] +async fn secure_padding_distribution_in_relay_writer() { + timeout(TokioDuration::from_secs(10), async { + let (mut client_side, relay_side) = duplex(512 * 1024); + let key = [0u8; 32]; + let iv = 0u128; + let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); + let rng = Arc::new(SecureRandom::new()); + let mut frame_buf = Vec::new(); + let mut decryptor = AesCtr::new(&key, iv); + + let mut padding_counts = [0usize; 4]; + let iterations = 180usize; + let payload = vec![0xAAu8; 100]; // 4-byte aligned + + for _ in 0..iterations { + write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + ) + .await + .expect("payload write must succeed"); + writer + .flush() + .await + .expect("writer flush must complete so encrypted frame becomes readable"); + + let mut len_buf = [0u8; 4]; 
+ client_side + .read_exact(&mut len_buf) + .await + .expect("must read encrypted secure length"); + let decrypted_len_bytes = decryptor.decrypt(&len_buf); + let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes + .try_into() + .expect("decrypted length must be 4 bytes"); + let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; + + assert!( + wire_len >= payload.len(), + "wire length must include at least payload bytes" + ); + let padding_len = wire_len - payload.len(); + assert!(padding_len >= 1 && padding_len <= 3); + padding_counts[padding_len] += 1; + + // Drain and decrypt frame bytes so CTR state stays aligned across writes. + let mut trash = vec![0u8; wire_len]; + client_side + .read_exact(&mut trash) + .await + .expect("must read encrypted secure frame body"); + let _ = decryptor.decrypt(&trash); + } + + for p in 1..=3 { + let count = padding_counts[p]; + assert!( + count > iterations / 8, + "padding length {p} is under-represented ({count}/{iterations})" + ); + } + }) + .await + .expect("secure padding distribution test exceeded runtime budget"); +} + +#[tokio::test] +async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + // Create an ME pool. 
+ let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // ConnRegistry ids are monotonic; reserve one id so we can predict the + // next session conn_id and close it deterministically without relying on + // writer-bound views such as active_conn_ids(). + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config.clone(), + buffer_pool.clone(), + "127.0.0.1:443".parse().unwrap(), + rng.clone(), + route_runtime.subscribe(), + route_runtime.snapshot(), + 0x1234_5678, + )); + + // Wait until session startup is visible, then unregister the predicted + // conn_id to close the per-session ME response channel. 
+ timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before channel close simulation"); + + me_pool.registry().unregister(target_conn_id).await; + + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("Session task must terminate after ME drop and client EOF") + .expect("Session task must not panic"); + + assert!( + result.is_ok(), + "Session should complete cleanly after ME drop when client closes, got: {:?}", + result + ); +} + +#[tokio::test] +async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { + let (client_reader_side, _client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let stats = Arc::new(Stats::new()); + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + // Predict the next conn_id so we can force-drop its ME channel deterministically. 
+ let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: "test-user-cutover".to_string(), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let runtime_clone = route_runtime.clone(); + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + "127.0.0.1:443".parse().unwrap(), + rng, + runtime_clone.subscribe(), + runtime_clone.snapshot(), + 0xC001_CAFE, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("ME session must start before race trigger"); + + // Race ME channel drop with route cutover and assert generic client-visible outcome. 
+ me_pool.registry().unregister(target_conn_id).await; + assert!( + route_runtime.set_mode(RelayRouteMode::Direct).is_some(), + "cutover must advance generation" + ); + + let relay_result = timeout(TokioDuration::from_secs(6), session_task) + .await + .expect("session must terminate under ME-drop + cutover race") + .expect("session task must not panic"); + + assert!( + matches!( + relay_result, + Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG + ), + "race outcome must remain generic and not leak ME internals, got: {:?}", + relay_result + ); +} + +#[tokio::test] +async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { + let stats = Arc::new(Stats::new()); + let me_pool = make_me_pool_for_abort_test(stats.clone()).await; + + for round in 0..32u64 { + let (client_reader_side, client_writer_side) = duplex(1024); + let (_relay_reader_side, relay_writer_side) = duplex(1024); + + let key = [0u8; 32]; + let iv = 0u128; + let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); + let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); + + let config = Arc::new(ProxyConfig::default()); + let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + + let (probe_conn_id, probe_rx) = me_pool.registry().register().await; + drop(probe_rx); + me_pool.registry().unregister(probe_conn_id).await; + let target_conn_id = probe_conn_id.wrapping_add(1); + + let success = HandshakeSuccess { + user: format!("stress-me-drop-eof-{round}"), + peer: "127.0.0.1:12345".parse().unwrap(), + dc_idx: 1, + proto_tag: ProtoTag::Intermediate, + enc_key: key, + enc_iv: iv, + dec_key: key, + dec_iv: iv, + is_tls: false, + }; + + let session_task = tokio::spawn(handle_via_middle_proxy( + crypto_reader, + crypto_writer, + success, + me_pool.clone(), + stats.clone(), + config, + buffer_pool, + 
"127.0.0.1:443".parse().unwrap(), + rng, + route_runtime.subscribe(), + route_runtime.snapshot(), + 0xD00D_0000 + round, + )); + + timeout(TokioDuration::from_millis(500), async { + loop { + if stats.get_current_connections_me() >= 1 { + break; + } + tokio::time::sleep(TokioDuration::from_millis(10)).await; + } + }) + .await + .expect("session must start before forced drop in burst round"); + + me_pool.registry().unregister(target_conn_id).await; + drop(client_writer_side); + + let result = timeout(TokioDuration::from_secs(2), session_task) + .await + .expect("burst round session must terminate quickly") + .expect("burst round session must not panic"); + + assert!( + result.is_ok(), + "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", + result + ); + } +} diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs new file mode 100644 index 0000000..f87d82b --- /dev/null +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -0,0 +1,210 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{Duration, Instant, timeout}; + +// ------------------------------------------------------------------ +// Priority 3: Async Relay HOL Blocking Prevention (OWASP ASVS 5.1.5) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_hol_blocking_prevention_regression() { + let stats = Arc::new(Stats::new()); + let user = "hol-user"; + + let (client_peer, relay_client) = duplex(65536); + let (relay_server, server_peer) = duplex(65536); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let 
relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let payload_size = 1024 * 10; + let s2c_payload = vec![0x41; payload_size]; + let c2s_payload = vec![0x42; payload_size]; + + let s2c_handle = tokio::spawn(async move { + sp_writer.write_all(&s2c_payload).await.unwrap(); + + let mut total_read = 0; + let mut buf = [0u8; 10]; + while total_read < payload_size { + let n = cp_reader.read(&mut buf).await.unwrap(); + total_read += n; + tokio::time::sleep(Duration::from_millis(100)).await; + } + }); + + let start = Instant::now(); + cp_writer.write_all(&c2s_payload).await.unwrap(); + + let mut server_buf = vec![0u8; payload_size]; + sp_reader.read_exact(&mut server_buf).await.unwrap(); + let elapsed = start.elapsed(); + + assert!(elapsed < Duration::from_millis(1000), "C->S must not be blocked by slow S->C (HOL blocking): {:?}", elapsed); + assert_eq!(server_buf, c2s_payload); + + s2c_handle.abort(); + relay_task.abort(); +} + +// ------------------------------------------------------------------ +// Priority 3: Data Quota Mid-Session Cutoff (OWASP ASVS 5.1.6) +// ------------------------------------------------------------------ + +#[tokio::test] +async fn relay_quota_mid_session_cutoff() { + let stats = Arc::new(Stats::new()); + let user = "quota-mid-user"; + let quota = 5000; + + let (client_peer, relay_client) = duplex(8192); + let (relay_server, server_peer) = duplex(8192); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut _cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, _sp_writer) = tokio::io::split(server_peer); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + 
Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + // Send 4000 bytes (Ok) + let buf1 = vec![0x42; 4000]; + cp_writer.write_all(&buf1).await.unwrap(); + let mut server_recv = vec![0u8; 4000]; + sp_reader.read_exact(&mut server_recv).await.unwrap(); + + // Send another 2000 bytes (Total 6000 > 5000) + let buf2 = vec![0x42; 2000]; + let _ = cp_writer.write_all(&buf2).await; + + let relay_res = timeout(Duration::from_secs(1), relay_task).await.unwrap(); + + match relay_res { + Ok(Err(ProxyError::DataQuotaExceeded { .. })) => { + // Expected + } + other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), + } + + let mut small_buf = [0u8; 1]; + let n = sp_reader.read(&mut small_buf).await.unwrap(); + assert_eq!(n, 0, "Server must see EOF after quota reached"); +} + +#[tokio::test] +async fn relay_chaos_half_close_crossfire_terminates_without_hang() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "half-close-crossfire", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(b"c2s-pre-half-close").await.unwrap(); + server_peer.write_all(b"s2c-pre-half-close").await.unwrap(); + + client_peer.shutdown().await.unwrap(); + tokio::time::sleep(Duration::from_millis(10)).await; + server_peer.shutdown().await.unwrap(); + + let done = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay must terminate after bilateral half-close") + .expect("relay task must not panic"); + assert!(done.is_ok(), "relay must terminate cleanly under half-close crossfire"); +} + +#[tokio::test] +#[ignore = "heavy soak; run 
manually"]
+async fn relay_soak_bidirectional_temporal_jitter_5k_rounds() {
+ let stats = Arc::new(Stats::new());
+
+ let (mut client_peer, relay_client) = duplex(65536);
+ let (relay_server, mut server_peer) = duplex(65536);
+
+ let (client_reader, client_writer) = tokio::io::split(relay_client);
+ let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+ let relay_task = tokio::spawn(relay_bidirectional(
+ client_reader,
+ client_writer,
+ server_reader,
+ server_writer,
+ 4096,
+ 4096,
+ "soak-jitter-user",
+ Arc::clone(&stats),
+ None,
+ Arc::new(BufferPool::new()),
+ ));
+
+ for i in 0..5_000u32 {
+ let c = [((i as u8).wrapping_mul(13)).wrapping_add(1); 17];
+ client_peer.write_all(&c).await.unwrap();
+ let mut c_seen = [0u8; 17];
+ server_peer.read_exact(&mut c_seen).await.unwrap();
+ assert_eq!(c_seen, c);
+
+ let s = [((i as u8).wrapping_mul(7)).wrapping_add(3); 23];
+ server_peer.write_all(&s).await.unwrap();
+ let mut s_seen = [0u8; 23];
+ client_peer.read_exact(&mut s_seen).await.unwrap();
+ assert_eq!(s_seen, s);
+
+ if i % 10 == 0 {
+ tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await;
+ }
+ }
+
+ drop(client_peer);
+ drop(server_peer);
+ let done = timeout(Duration::from_secs(2), relay_task)
+ .await
+ .expect("relay must stop after soak peers close")
+ .expect("relay task must not panic");
+ assert!(done.is_ok());
+}
diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs
new file mode 100644
index 0000000..7a2f8b7
--- /dev/null
+++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs
@@ -0,0 +1,416 @@
+use super::relay_bidirectional;
+use crate::error::ProxyError;
+use crate::stats::Stats;
+use crate::stream::BufferPool;
+use rand::rngs::StdRng;
+use rand::{RngExt, SeedableRng};
+use std::sync::Arc;
+use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt};
+use tokio::time::{timeout, Duration, Instant};
+
+async fn read_available<R: AsyncRead + Unpin>(reader: &mut 
R, budget: Duration) -> usize { + let start = Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 256]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn integration_full_duplex_exact_budget_then_hard_cutoff() { + let stats = Arc::new(Stats::new()); + let user = "quota-full-duplex-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0x10, 0x11, 0x12, 0x13]).await.unwrap(); + let mut c2s = [0u8; 4]; + server_peer.read_exact(&mut c2s).await.unwrap(); + assert_eq!(c2s, [0x10, 0x11, 0x12, 0x13]); + + server_peer + .write_all(&[0x20, 0x21, 0x22, 0x23, 0x24, 0x25]) + .await + .unwrap(); + let mut s2c = [0u8; 6]; + client_peer.read_exact(&mut s2c).await.unwrap(); + assert_eq!(s2c, [0x20, 0x21, 0x22, 0x23, 0x24, 0x25]); + + let _ = client_peer.write_all(&[0x99]).await; + let _ = server_peer.write_all(&[0x88]).await; + + let mut probe_server = [0u8; 1]; + let mut probe_client = [0u8; 1]; + let leaked_to_server = timeout(Duration::from_millis(120), server_peer.read(&mut probe_server)).await; + let leaked_to_client = timeout(Duration::from_millis(120), client_peer.read(&mut probe_client)).await; + + assert!( + !matches!(leaked_to_server, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra client byte must be forwarded" + ); + 
assert!( + !matches!(leaked_to_client, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra server byte must be forwarded" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota cutoff") + .expect("relay task must not panic"); + + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" + )); + assert!(stats.get_user_total_octets(user) <= 10); +} + +#[tokio::test] +async fn negative_preloaded_quota_blocks_both_directions_immediately() { + let stats = Arc::new(Stats::new()); + let user = "quota-preloaded-cutoff-user"; + stats.add_user_octets_from(user, 5); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(5), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0x41, 0x42]), + server_peer.write_all(&[0x51, 0x52]), + ); + + let leaked_to_server = read_available(&mut server_peer, Duration::from_millis(120)).await; + let leaked_to_client = read_available(&mut client_peer, Duration::from_millis(120)).await; + + assert_eq!(leaked_to_server, 0, "preloaded limit must block C->S immediately"); + assert_eq!(leaked_to_client, 0, "preloaded limit must block S->C immediately"); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under preloaded cutoff") + .expect("relay task must not panic"); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}))); + assert!(stats.get_user_total_octets(user) <= 5); +} + +#[tokio::test] +async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() { + let stats = Arc::new(Stats::new()); + let user = "quota-one-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!(client_peer.write_all(&[0xAA]), server_peer.write_all(&[0xBB])); + + let mut to_server = [0u8; 1]; + let mut to_client = [0u8; 1]; + + let delivered_server = match timeout(Duration::from_millis(120), server_peer.read(&mut to_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let delivered_client = match timeout(Duration::from_millis(120), client_peer.read(&mut to_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( + delivered_server + delivered_client <= 1, + "quota=1 must not allow >1 forwarded byte across both directions" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota=1") + .expect("relay task must not panic"); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}))); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-blackhat-jitter-user"; + let quota = 32u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut delivered_to_server = 0usize; + let mut delivered_to_client = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x5A]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + delivered_to_server = delivered_to_server.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA5]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + delivered_to_client = delivered_to_client.saturating_add(n); + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under black-hat jitter attack") + .expect("relay task must not panic"); + + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})));
+ assert!(
+ delivered_to_server + delivered_to_client <= quota as usize,
+ "combined forwarded bytes must never exceed configured quota"
+ );
+ assert!(stats.get_user_total_octets(user) <= quota);
+}
+
+#[tokio::test]
+async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invariants() {
+ let mut rng = StdRng::seed_from_u64(0xD15C_A11E_F00D_BAAD);
+
+ for case in 0..48u64 {
+ let stats = Arc::new(Stats::new());
+ let user = format!("quota-fuzz-schedule-{case}");
+ let quota = rng.random_range(1u64..=32u64);
+
+ let (mut client_peer, relay_client) = duplex(4096);
+ let (relay_server, mut server_peer) = duplex(4096);
+ let (client_reader, client_writer) = tokio::io::split(relay_client);
+ let (server_reader, server_writer) = tokio::io::split(relay_server);
+
+ let relay_user = user.clone();
+ let relay_stats = Arc::clone(&stats);
+ let relay = tokio::spawn(async move {
+ relay_bidirectional(
+ client_reader,
+ client_writer,
+ server_reader,
+ server_writer,
+ 256,
+ 256,
+ &relay_user,
+ Arc::clone(&relay_stats),
+ Some(quota),
+ Arc::new(BufferPool::new()),
+ )
+ .await
+ });
+
+ let mut delivered_total = 0usize;
+
+ for _ in 0..96usize {
+ if relay.is_finished() {
+ break;
+ }
+
+ if rng.random::<bool>() {
+ let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
+ let mut one = [0u8; 1];
+ if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await {
+ delivered_total = delivered_total.saturating_add(n);
+ }
+ } else {
+ let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
+ let mut one = [0u8; 1];
+ if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await {
+ delivered_total = delivered_total.saturating_add(n);
+ }
+ }
+ }
+
+ drop(client_peer);
+ drop(server_peer);
+
+ let relay_result = timeout(Duration::from_secs(2), relay)
+ .await
+ .expect("fuzz relay must terminate")
+ .expect("fuzz relay task must not panic");
+
+ assert!(
+ relay_result.is_ok() || matches!(relay_result, 
Err(ProxyError::DataQuotaExceeded { .. })), + "relay must either close cleanly or terminate via typed quota error" + ); + assert!( + delivered_total <= quota as usize, + "fuzz case {case}: forwarded bytes must not exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes must not exceed quota" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-stress-multi-relay-user"; + let quota = 64u64; + + let mut workers = Vec::new(); + + for worker_id in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut delivered = 0usize; + + for step in 0..96u8 { + if relay.is_finished() { + break; + } + + if ((step as usize + worker_id as usize) & 1) == 0 { + let _ = client_peer.write_all(&[step ^ 0x3C]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await { + delivered = delivered.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[step ^ 0xC3]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await { + delivered = delivered.saturating_add(n); + } + } + + tokio::time::sleep(Duration::from_millis((((worker_id as u64) + 
(step as u64)) % 3) + 1)).await; + } + + drop(client_peer); + drop(server_peer); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must either close cleanly or terminate via typed quota error" + ); + delivered + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = delivered_sum.saturating_add(worker.await.expect("stress worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must hold under concurrent mixed-direction relay stress" + ); + assert!( + delivered_sum <= quota as usize, + "combined delivered bytes across relays must stay within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs new file mode 100644 index 0000000..4add5f0 --- /dev/null +++ b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs @@ -0,0 +1,413 @@ +use super::*; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::Instant; + +#[test] +fn quota_lock_same_user_returns_same_arc_instance() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let a = quota_user_lock("quota-lock-same-user"); + let b = quota_user_lock("quota-lock-same-user"); + assert!(Arc::ptr_eq(&a, &b)); +} + +#[test] +fn quota_lock_parallel_same_user_reuses_single_lock() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); 
+ map.clear();
+
+ let user = "quota-lock-parallel-same";
+ let mut handles = Vec::new();
+
+ for _ in 0..64 {
+ handles.push(std::thread::spawn(move || quota_user_lock(user)));
+ }
+
+ let first = handles
+ .remove(0)
+ .join()
+ .expect("thread must return lock handle");
+
+ for handle in handles {
+ let got = handle.join().expect("thread must return lock handle");
+ assert!(Arc::ptr_eq(&first, &got));
+ }
+}
+
+#[test]
+fn quota_lock_unique_users_materialize_distinct_entries() {
+ let _guard = super::quota_user_lock_test_scope();
+ let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+
+ map.clear();
+
+ let base = format!("quota-lock-distinct-{}", std::process::id());
+ let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
+ .map(|idx| format!("{base}-{idx}"))
+ .collect();
+
+ for user in &users {
+ let _ = quota_user_lock(user);
+ }
+
+ for user in &users {
+ assert!(map.get(user).is_some(), "lock cache must contain entry for {user}");
+ }
+}
+
+#[test]
+fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
+ let _guard = super::quota_user_lock_test_scope();
+ let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+
+ map.clear();
+
+ let base = format!("quota-lock-churn-{}", std::process::id());
+ for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
+ let _ = quota_user_lock(&format!("{base}-{idx}"));
+ }
+
+ assert!(
+ map.len() <= QUOTA_USER_LOCKS_MAX,
+ "quota lock cache must stay bounded under unique-user churn"
+ );
+}
+
+#[test]
+fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
+ let _guard = super::quota_user_lock_test_scope();
+ let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+ map.clear();
+
+ let prefix = format!("quota-held-{}", std::process::id());
+ let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
+ for idx in 0..QUOTA_USER_LOCKS_MAX {
+ retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
+ }
+
+ assert_eq!(
+ map.len(),
+ QUOTA_USER_LOCKS_MAX,
+ "cache must be saturated for overflow 
check" + ); + + let overflow_user = format!("quota-overflow-{}", std::process::id()); + let overflow_a = quota_user_lock(&overflow_user); + let overflow_b = quota_user_lock(&overflow_user); + + assert_eq!( + map.len(), + QUOTA_USER_LOCKS_MAX, + "overflow path must not grow lock cache" + ); + assert!( + map.get(&overflow_user).is_none(), + "overflow user lock must stay outside bounded cache under saturation" + ); + assert!( + Arc::ptr_eq(&overflow_a, &overflow_b), + "overflow user must receive stable striped overflow lock while saturated" + ); + + drop(retained); +} + +#[test] +fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + // Saturate with retained strong references first so parallel tests cannot + // reclaim our fixture entries before we validate the reclaim path. + let prefix = format!("quota-reclaim-drop-{}", std::process::id()); + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); + } + + drop(retained); + + let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); + let overflow = quota_user_lock(&overflow_user); + + assert!( + map.get(&overflow_user).is_some(), + "after reclaiming stale entries, overflow user should become cacheable" + ); + assert!( + Arc::strong_count(&overflow) >= 2, + "cacheable overflow lock should be held by both map and caller" + ); +} + +#[test] +fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-saturated-held-{}-{idx}", std::process::id()))); + } + + let 
overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
+ let a = quota_user_lock(&overflow_user);
+ let b = quota_user_lock(&overflow_user);
+
+ assert!(
+ Arc::ptr_eq(&a, &b),
+ "same user must not receive distinct locks under saturation because that enables quota race bypass"
+ );
+
+ drop(retained);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
+ let _guard = super::quota_user_lock_test_scope();
+ let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
+ map.clear();
+
+ let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
+ for idx in 0..QUOTA_USER_LOCKS_MAX {
+ retained.push(quota_user_lock(&format!("quota-saturated-race-held-{}-{idx}", std::process::id())));
+ }
+
+ let stats = Arc::new(Stats::new());
+ let user = format!("quota-saturated-race-user-{}", std::process::id());
+ let gate = Arc::new(Barrier::new(2));
+
+ let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
+ tokio::spawn(async move {
+ let counters = Arc::new(SharedCounters::new());
+ let quota_exceeded = Arc::new(AtomicBool::new(false));
+ let mut io = StatsIo::new(
+ tokio::io::sink(),
+ counters,
+ Arc::clone(&stats),
+ user,
+ Some(1),
+ quota_exceeded,
+ Instant::now(),
+ );
+ gate.wait().await;
+ io.write_all(&[label]).await
+ })
+ };
+
+ let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
+ let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
+
+ let _ = tokio::time::timeout(Duration::from_secs(2), async {
+ let _ = one.await.expect("task one must not panic");
+ let _ = two.await.expect("task two must not panic");
+ })
+ .await
+ .expect("quota race workers must complete");
+
+ assert!(
+ stats.get_user_total_octets(&user) <= 1,
+ "saturated lock path must never overshoot quota for same user"
+ );
+
+ drop(retained);
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn 
quota_lock_saturation_stress_same_user_never_overshoots_quota() { + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-saturated-stress-held-{}-{idx}", std::process::id()))); + } + + for round in 0..128u32 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); + let gate = Arc::new(Barrier::new(2)); + + let one = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x31]).await + }) + }; + + let two = { + let stats = Arc::clone(&stats); + let user = user.clone(); + let gate = Arc::clone(&gate); + tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(1), + quota_exceeded, + Instant::now(), + ); + gate.wait().await; + io.write_all(&[0x32]).await + }) + }; + + let _ = one.await.expect("stress task one must not panic"); + let _ = two.await.expect("stress task two must not panic"); + + assert!( + stats.get_user_total_octets(&user) <= 1, + "round {round}: saturated path must not overshoot quota" + ); + } + + drop(retained); +} + +#[test] +fn quota_error_classifier_accepts_internal_quota_sentinel_only() { + let err = quota_io_error(); + assert!(is_quota_io_error(&err)); +} + +#[test] +fn quota_error_classifier_rejects_plain_permission_denied() { + let err = 
std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); + assert!(!is_quota_io_error(&err)); +} + +#[test] +fn quota_lock_test_scope_recovers_after_guard_poison() { + let poison_result = std::thread::spawn(|| { + let _guard = super::quota_user_lock_test_scope(); + panic!("intentional test-only guard poison"); + }) + .join(); + assert!(poison_result.is_err(), "poison setup thread must panic"); + + let _guard = super::quota_user_lock_test_scope(); + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let a = quota_user_lock("quota-lock-poison-recovery-user"); + let b = quota_user_lock("quota-lock-poison-recovery-user"); + assert!(Arc::ptr_eq(&a, &b)); +} + +#[tokio::test] +async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { + let stats = Arc::new(Stats::new()); + let user = "quota-zero-user"; + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(b"x") + .await + .expect("client write must succeed"); + + let mut probe = [0u8; 1]; + let forwarded = tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; + if let Ok(Ok(n)) = forwarded { + assert_eq!(n, 0, "zero quota path must not forward payload bytes"); + } + + let result = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under zero quota") + .expect("relay task must not panic"); + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. 
}))); +} + +#[tokio::test] +async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { + let stats = Arc::new(Stats::new()); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + "quota-none-burst-user", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + let c2s = vec![0xA5; 2048]; + let s2c = vec![0x5A; 1536]; + + client_peer.write_all(&c2s).await.expect("client burst write must succeed"); + let mut got_c2s = vec![0u8; c2s.len()]; + server_peer.read_exact(&mut got_c2s).await.expect("server must receive c2s burst"); + assert_eq!(got_c2s, c2s); + + server_peer.write_all(&s2c).await.expect("server burst write must succeed"); + let mut got_s2c = vec![0u8; s2c.len()]; + client_peer.read_exact(&mut got_s2c).await.expect("client must receive s2c burst"); + assert_eq!(got_s2c, s2c); + + drop(client_peer); + drop(server_peer); + + let done = tokio::time::timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after peers close") + .expect("relay task must not panic"); + assert!(done.is_ok()); +} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs new file mode 100644 index 0000000..e9e6a61 --- /dev/null +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -0,0 +1,300 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::{timeout, Duration}; + 
+fn assert_is_prefix(received: &[u8], sent: &[u8], direction: &str) { + assert!( + sent.starts_with(received), + "{direction} stream corruption: received={} sent={} (received must be prefix of sent)", + received.len(), + sent.len() + ); +} + +async fn drain_available(reader: &mut R, out: &mut Vec, rounds: usize) { + for _ in 0..rounds { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(2), reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => out.extend_from_slice(&buf[..n]), + Ok(Err(_)) | Err(_) => break, + } + } +} + +#[tokio::test] +async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { + let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + + for case in 0..64u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-model-fuzz-{case}"); + let quota = rng.random_range(1u64..=64u64); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + let do_c2s = rng.random::(); + let chunk_len = rng.random_range(1usize..=12usize); + let mut chunk = vec![0u8; chunk_len]; + for b in &mut chunk { + *b = rng.random::(); + } + + if do_c2s { + if client_peer.write_all(&chunk).await.is_ok() { + sent_c2s.extend_from_slice(&chunk); + } + } else if server_peer.write_all(&chunk).await.is_ok() 
{ + sent_s2c.extend_from_slice(&chunk); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); + assert!( + recv_at_server.len() + recv_at_client.len() <= quota as usize, + "fuzz case {case}: delivered bytes exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes exceed quota" + ); + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("fuzz relay must terminate") + .expect("fuzz relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "fuzz case {case}: relay must end cleanly or with typed quota error" + ); + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); + assert!(stats.get_user_total_octets(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byte() { + let stats = Arc::new(Stats::new()); + let user = "quota-dual-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let gate = Arc::new(Barrier::new(3)); + + let writer_c2s = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { 
+ gate.wait().await; + let _ = client_peer.write_all(&[0xA1]).await; + client_peer + }) + }; + + let writer_s2c = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { + gate.wait().await; + let _ = server_peer.write_all(&[0xB2]).await; + server_peer + }) + }; + + gate.wait().await; + + let mut client_peer = writer_c2s.await.expect("c2s writer must not panic"); + let mut server_peer = writer_s2c.await.expect("s2c writer must not panic"); + + let mut got_at_server = [0u8; 1]; + let mut got_at_client = [0u8; 1]; + + let n_server = match timeout(Duration::from_millis(120), server_peer.read(&mut got_at_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let n_client = match timeout(Duration::from_millis(120), client_peer.read(&mut got_at_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( + n_server + n_client <= 1, + "quota=1 race must not forward both concurrent direction bytes" + ); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("quota race relay must terminate") + .expect("quota race relay task must not panic"); + + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}))); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_model_load() { + let stats = Arc::new(Stats::new()); + let user = "quota-model-stress-user"; + let quota = 96u64; + + let mut workers = Vec::new(); + for worker_id in 0..6u64 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_7F4A_7C15 ^ worker_id); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..64usize { + if relay.is_finished() { + break; + } + + let choose_c2s = rng.random::(); + let len = rng.random_range(1usize..=10usize); + let mut payload = vec![0u8; len]; + for b in &mut payload { + *b = rng.random::(); + } + + if choose_c2s { + if client_peer.write_all(&payload).await.is_ok() { + sent_c2s.extend_from_slice(&payload); + } + } else if server_peer.write_all(&payload).await.is_ok() { + sent_s2c.extend_from_slice(&payload); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "stress C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "stress S->C"); + 
} + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must end cleanly or with typed quota error" + ); + + recv_at_server.len() + recv_at_client.len() + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = delivered_sum.saturating_add(worker.await.expect("worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must never overshoot under concurrent multi-relay model load" + ); + assert!( + delivered_sum <= quota as usize, + "aggregate delivered bytes across relays must remain within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs new file mode 100644 index 0000000..207d603 --- /dev/null +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -0,0 +1,194 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{timeout, Duration}; + +async fn read_available(reader: &mut R, budget_ms: u64) -> usize { + let mut total = 0usize; + loop { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(budget_ms), reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + total +} + +#[tokio::test] +async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-client-chunk"; + + // Leave only 1 byte remaining under quota. 
+ stats.add_user_octets_from(user, 9); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Single chunk attempts to cross remaining budget (4 > 1). + client_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer.shutdown().await.unwrap(); + + let forwarded = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after quota overflow attempt") + .expect("relay task must not panic"); + + assert_eq!( + forwarded, 0, + "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + ); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!( + stats.get_user_total_octets(user) <= 10, + "accounted bytes must never exceed quota after overflowing chunk" + ); +} + +#[tokio::test] +async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_off() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-boundary"; + + // Leave exactly 4 bytes remaining. 
+ stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Exact boundary write should pass once. + client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + + let mut exact = [0u8; 4]; + timeout(Duration::from_secs(1), server_peer.read_exact(&mut exact)) + .await + .unwrap() + .unwrap(); + assert_eq!(exact, [0xAA, 0xBB, 0xCC, 0xDD]); + + // Any extra byte after boundary should be rejected/cut off. + let _ = client_peer.write_all(&[0xEE]).await; + client_peer.shutdown().await.unwrap(); + + let leaked_after = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate at quota boundary") + .expect("relay task must not panic"); + + assert_eq!( + leaked_after, 0, + "no bytes may pass after exact boundary is consumed" + ); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
}))); + assert!(stats.get_user_total_octets(user) <= 10); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-stress"; + let quota = 12u64; + + let mut handles = Vec::new(); + for _ in 0..4usize { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + handles.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + // Aggressive sender tries to overflow shared user quota. + let burst = vec![0x5Au8; 64]; + let _ = client_peer.write_all(&burst).await; + let _ = client_peer.shutdown().await; + + let mut forwarded = 0usize; + forwarded = forwarded.saturating_add(read_available(&mut server_peer, 40).await); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. 
})), + "stress relay must finish cleanly or with typed quota error" + ); + forwarded + })); + } + + let mut forwarded_sum = 0usize; + for handle in handles { + forwarded_sum = forwarded_sum.saturating_add(handle.await.expect("worker must not panic")); + } + + assert!( + forwarded_sum <= quota as usize, + "aggregate forwarded bytes across relays must stay within global user quota" + ); + assert!( + stats.get_user_total_octets(user) <= quota, + "global accounted bytes must stay within quota under overflow stress" + ); +} diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs new file mode 100644 index 0000000..1cd5920 --- /dev/null +++ b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs @@ -0,0 +1,288 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::{Duration, timeout}; + +fn saturate_lock_cache() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); + } + retained +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +#[tokio::test] +async fn positive_writer_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-writer-positive"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + 
counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); + + // Let the initial deferred wake fire while contention is still active. + tokio::time::sleep(Duration::from_millis(4)).await; + + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), writer) + .await + .expect("writer must be re-polled and complete after lock release") + .expect("writer task must not panic"); + assert!(completed.is_ok(), "writer must complete after lock release"); +} + +#[tokio::test] +async fn edge_reader_progresses_after_contention_release_without_external_wake() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-reader-edge"; + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before read"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let reader = tokio::spawn(async move { + let mut one = [0u8; 1]; + io.read(&mut one).await + }); + + tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + let completed = timeout(Duration::from_millis(250), reader) + .await + .expect("reader must be re-polled and complete after lock release") + .expect("reader task must not panic"); + assert!(completed.is_ok(), "reader must complete after lock release"); +} + +#[tokio::test] +async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = "quota-liveness-adversarial"; + let stats = Arc::new(Stats::new()); + + let lock = 
quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before adversarial write"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); + + // Force multiple scheduler rounds while lock remains held so the first + // deferred wake has already been consumed under contention. + for _ in 0..32 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + let completed = timeout(Duration::from_millis(300), writer) + .await + .expect("writer must not stay parked forever after release") + .expect("writer task must not panic"); + assert!(completed.is_ok()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn integration_parallel_waiters_resume_after_single_release_event() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let user = format!("quota-liveness-integration-{}", std::process::id()); + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(13)); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock before launching waiters"); + + let mut waiters = Vec::new(); + for _ in 0..12 { + let stats = Arc::clone(&stats); + let user = user.clone(); + let barrier = Arc::clone(&barrier); + waiters.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(4096), + quota_exceeded, + tokio::time::Instant::now(), + ); + barrier.wait().await; + io.write_all(&[0x33]).await + })); + } + + barrier.wait().await; + 
tokio::time::sleep(Duration::from_millis(4)).await; + drop(held_guard); + + timeout(Duration::from_secs(1), async { + for waiter in waiters { + let outcome = waiter.await.expect("waiter must not panic"); + assert!(outcome.is_ok(), "waiter must resume and complete after release"); + } + }) + .await + .expect("all waiters must complete in bounded time"); +} + +#[tokio::test] +async fn light_fuzz_release_timing_matrix_preserves_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + let mut seed = 0xD1CE_F00D_0123_4567u64; + for round in 0..64u32 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let delay_ms = 1 + (seed & 0x7) as u64; + let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold user quota lock in fuzz round"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); + + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + drop(held_guard); + + let done = timeout(Duration::from_millis(300), writer) + .await + .expect("fuzz round writer must complete") + .expect("fuzz writer task must not panic"); + assert!(done.is_ok(), "fuzz round writer must not stall after release"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_repeated_contention_cycles_remain_live() { + let _guard = quota_test_guard(); + + let _retained = saturate_lock_cache(); + let stats = Arc::new(Stats::new()); + + for cycle in 0..40u32 { + let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); + let lock = quota_user_lock(&user); + 
let held_guard = lock + .try_lock() + .expect("test must hold lock before stress cycle"); + + let mut tasks = Vec::new(); + for _ in 0..6 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + io.write_all(&[0x55]).await + })); + } + + tokio::task::yield_now().await; + drop(held_guard); + + timeout(Duration::from_millis(700), async { + for task in tasks { + let outcome = task.await.expect("stress task must not panic"); + assert!(outcome.is_ok(), "stress writer must complete"); + } + }) + .await + .expect("stress cycle must finish in bounded time"); + } +} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs new file mode 100644 index 0000000..2dabaa3 --- /dev/null +++ b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs @@ -0,0 +1,304 @@ +use super::*; +use crate::stats::Stats; +use dashmap::DashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::task::{Context, Waker}; +use tokio::io::{ReadBuf, AsyncWriteExt}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +fn quota_test_guard() -> impl Drop { + super::quota_user_lock_test_scope() +} + +fn saturate_quota_user_locks() -> Vec>> { + let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); + map.clear(); + + let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); + for idx in 0..QUOTA_USER_LOCKS_MAX { + 
retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); + } + retained +} + +#[tokio::test] +async fn positive_contended_writer_emits_deferred_wake_for_liveness() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-positive-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); + assert!(pending.is_pending()); + + timeout(Duration::from_millis(100), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must receive deferred wake"); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); + assert!(ready.is_ready(), "writer must progress after contention release"); +} + +#[tokio::test] +async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-blackhat-writer"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling writer"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + 
Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + for _ in 0..512 { + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); + assert!(poll.is_pending(), "writer must stay pending while lock is held"); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending writer retries must not trigger wake storm; observed wakes={wakes}" + ); + + drop(held_guard); + let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn edge_read_path_contention_keeps_wake_budget_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-read-edge"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before polling reader"); + + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + + for _ in 0..512 { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + tokio::task::yield_now().await; + } + + let wakes = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes <= 128, + "pending reader retries must not trigger wake storm; observed wakes={wakes}" + ); + + drop(held_guard); + let mut 
buf = ReadBuf::new(&mut storage); + let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(ready.is_ready()); +} + +#[tokio::test] +async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let stats = Arc::new(Stats::new()); + let user = "quota-waker-fuzz-user"; + + let lock = quota_user_lock(user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before fuzz polling"); + + let counters_w = Arc::new(SharedCounters::new()); + let mut writer_io = StatsIo::new( + tokio::io::sink(), + counters_w, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let counters_r = Arc::new(SharedCounters::new()); + let mut reader_io = StatsIo::new( + tokio::io::empty(), + counters_r, + Arc::clone(&stats), + user.to_string(), + Some(1024), + Arc::new(AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut seed = 0xBADC_0FFE_EE11_2211u64; + let mut storage = [0u8; 1]; + + for _ in 0..1024 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + if (seed & 1) == 0 { + let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); + assert!(poll.is_pending()); + } else { + let mut buf = ReadBuf::new(&mut storage); + let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(poll.is_pending()); + } + tokio::task::yield_now().await; + } + + assert!( + wake_counter.wakes.load(Ordering::Relaxed) <= 192, + "mixed contention fuzz must keep deferred wake count tightly bounded" + ); + + drop(held_guard); + let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); + assert!(ready_w.is_ready()); + + let mut buf = ReadBuf::new(&mut storage); + let ready_r = 
Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); + assert!(ready_r.is_ready()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] +async fn stress_many_contended_writers_complete_after_release() { + let _guard = quota_test_guard(); + + let _retained = saturate_quota_user_locks(); + let user = "quota-waker-stress-user".to_string(); + let stats = Arc::new(Stats::new()); + + let lock = quota_user_lock(&user); + let held_guard = lock + .try_lock() + .expect("test must hold overflow lock before launching contended tasks"); + + let mut tasks = Vec::new(); + for _ in 0..32 { + let stats = Arc::clone(&stats); + let user = user.clone(); + tasks.push(tokio::spawn(async move { + let counters = Arc::new(SharedCounters::new()); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let mut io = StatsIo::new( + tokio::io::sink(), + counters, + stats, + user, + Some(2048), + quota_exceeded, + tokio::time::Instant::now(), + ); + + io.write_all(&[0xAA]).await + })); + } + + for _ in 0..8 { + tokio::task::yield_now().await; + } + + drop(held_guard); + + timeout(Duration::from_secs(2), async { + for task in tasks { + let result = task.await.expect("stress task must not panic"); + assert!(result.is_ok(), "task must complete after lock release"); + } + }) + .await + .expect("all contended writer tasks must finish in bounded time after release"); +} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs new file mode 100644 index 0000000..b9b3478 --- /dev/null +++ b/src/proxy/tests/relay_security_tests.rs @@ -0,0 +1,1197 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Mutex; +use 
std::task::{Context, Poll}; +use std::task::Waker; +use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; +use tokio::time::{Duration, timeout}; + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl std::task::Wake for WakeCounter { + fn wake(self: Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc<Self>) { + self.wakes.fetch_add(1, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn quota_lock_contention_does_not_self_wake_pending_writer() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + + let stats = Arc::new(Stats::new()); + let user = "quota-lock-contention-user"; + + let lock = super::quota_user_lock(user); + let _held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(poll.is_pending(), "writer must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "contended quota lock must not self-wake immediately and spin the executor" + ); +} + +#[tokio::test] +async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + + let stats = Arc::new(Stats::new()); + let 
user = "quota-lock-writer-liveness-user"; + + let lock = super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling writer"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::sink(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + + let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); + assert!(first.is_pending(), "writer must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "deferred wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("contended writer must schedule a deferred wake in bounded time"); + let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); + assert!( + wakes_after_first_yield >= 1, + "contended writer must schedule at least one deferred wake for liveness" + ); + + let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); + assert!(second.is_pending(), "writer remains pending while lock is still held"); + + for _ in 0..8 { + tokio::task::yield_now().await; + } + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + wakes_after_first_yield, + "writer contention should not schedule unbounded wake storms before lock acquisition" + ); + + drop(held_lock); + let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); + assert!(released.is_ready(), "writer must make progress once quota lock is released"); +} + +#[tokio::test] +async fn 
quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { + let _guard = super::quota_user_lock_test_scope(); + let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); + map.clear(); + + let stats = Arc::new(Stats::new()); + let user = "quota-lock-read-liveness-user"; + + let lock = super::quota_user_lock(user); + let held_lock = lock + .try_lock() + .expect("test must hold the per-user quota lock before polling reader"); + + let counters = Arc::new(super::SharedCounters::new()); + let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut io = super::StatsIo::new( + tokio::io::empty(), + counters, + Arc::clone(&stats), + user.to_string(), + Some(1024), + quota_exceeded, + tokio::time::Instant::now(), + ); + + let wake_counter = Arc::new(WakeCounter::default()); + let waker = Waker::from(Arc::clone(&wake_counter)); + let mut cx = Context::from_waker(&waker); + let mut storage = [0u8; 1]; + let mut buf = ReadBuf::new(&mut storage); + + let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); + assert!(first.is_pending(), "reader must remain pending while lock is contended"); + assert_eq!( + wake_counter.wakes.load(Ordering::Relaxed), + 0, + "read contention wake must not fire synchronously" + ); + + timeout(Duration::from_millis(50), async { + loop { + if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("read contention must schedule a deferred wake in bounded time"); + + drop(held_lock); + let mut buf_after_release = ReadBuf::new(&mut storage); + let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); + assert!(released.is_ready(), "reader must make progress once quota lock is released"); +} + +#[tokio::test] +async fn relay_bidirectional_enforces_live_user_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-user"; + stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = 
duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(8), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x10, 0x20, 0x30, 0x40]) + .await + .expect("client write must succeed"); + + let mut forwarded = [0u8; 4]; + let _ = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut forwarded), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), + "relay must surface a typed quota error once live quota is exceeded" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0xde, 0xad, 0xbe, 0xef]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + 
let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "no full server payload should be forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() { + let stats = Arc::new(Stats::new()); + let quota_user = "partial-leak-user"; + stats.add_user_octets_from(quota_user, 3); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .expect("server write must succeed"); + + let mut observed = [0u8; 8]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn 
relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { + let stats = Arc::new(Stats::new()); + let quota_user = "zero-quota-user"; + + for payload_len in [1usize, 16, 512, 4096] { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(0), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0x7f; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under zero-quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n > 0), + "zero quota must not forward any server bytes for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "zero quota must terminate with the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { + let stats = Arc::new(Stats::new()); + let quota_user = "exact-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, 
+ 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(4), + Arc::new(BufferPool::new()), + )); + + server_peer + .write_all(&[0x91, 0x92, 0x93, 0x94]) + .await + .expect("server write must succeed at exact quota boundary"); + + let mut observed = [0u8; 4]; + client_peer + .read_exact(&mut observed) + .await + .expect("client must receive the full payload at the exact quota boundary"); + assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish after exact boundary delivery") + .expect("relay task must not panic"); + + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must close with a typed quota error after reaching the exact boundary" + ); +} + +#[tokio::test] +async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { + let stats = Arc::new(Stats::new()); + let quota_user = "client-exhausted-user"; + stats.add_user_octets_from(quota_user, 1); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + client_peer + .write_all(&[0x51, 0x52, 0x53, 0x54]) + .await + .expect("client write must succeed even when quota is already exhausted"); + + let mut observed = [0u8; 4]; + let forwarded = timeout( + Duration::from_millis(200), + server_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + 
assert!( + !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), + "client payload must not be fully forwarded once quota is already exhausted" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must still terminate with a typed quota error" + ); +} + +#[tokio::test] +async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { + let stats = Arc::new(Stats::new()); + let quota_user = "quota-fuzz-user"; + stats.add_user_octets_from(quota_user, 2); + + for payload_len in [1usize, 32, 1024, 8192] { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + quota_user, + Arc::clone(&stats), + Some(2), + Arc::new(BufferPool::new()), + )); + + let payload = vec![0xaa; payload_len]; + let _ = server_peer.write_all(&payload).await; + + let mut observed = vec![0u8; payload_len]; + let forwarded = timeout( + Duration::from_millis(200), + client_peer.read_exact(&mut observed), + ) + .await; + + let relay_result = timeout(Duration::from_secs(2), relay_task) + .await + .expect("relay task must finish under quota cutoff") + .expect("relay task must not panic"); + + assert!( + !matches!(forwarded, Ok(Ok(n)) if n == payload_len), + "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" + ); + assert!( + matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), + "relay must keep returning the typed quota error for payload_len={payload_len}" + ); + } +} + +#[tokio::test] +async fn relay_bidirectional_terminates_on_activity_timeout() { + tokio::time::pause(); + let stats = 
Arc::new(Stats::new()); + let user = "timeout-user"; + + let (client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, // No quota + Arc::new(BufferPool::new()), + )); + + // Wait past the activity timeout threshold (1800 seconds) + buffer + tokio::time::sleep(Duration::from_secs(1805)).await; + + // Resume time to process timeouts + tokio::time::resume(); + + let relay_result = timeout(Duration::from_secs(1), relay_task) + .await + .expect("relay task must finish inside bounded timeout due to inactivity cutoff") + .expect("relay task must not panic"); + + assert!( + relay_result.is_ok(), + "relay should complete successfully on scheduled inactivity timeout" + ); + + // Verify client/server sockets are closed + drop(client_peer); + drop(server_peer); +} + +#[tokio::test] +async fn relay_bidirectional_watchdog_resists_premature_execution() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let user = "activity-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + // Advance by half the timeout + tokio::time::sleep(Duration::from_secs(900)).await; + + // Provide activity + client_peer + .write_all(&[0xaa, 0xbb]) + .await + .expect("client write must succeed"); + 
client_peer.flush().await.unwrap(); + + // Advance by another half (total time since start is 1800, but since last activity is 900) + tokio::time::sleep(Duration::from_secs(900)).await; + + tokio::time::resume(); + + // Re-evaluating the task, it should NOT have timed out and still be pending + let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; + assert!( + relay_result.is_err(), + "Relay must not exit prematurely as long as activity was received before timeout" + ); + + // Explicitly drop sockets to cleanly shut down relay loop + drop(client_peer); + drop(server_peer); + + let completion = timeout(Duration::from_secs(1), relay_task).await + .expect("relay task must complete securely after client disconnection") + .expect("relay task must not panic"); + assert!(completion.is_ok(), "relay exits clean"); +} + +#[tokio::test] +async fn relay_bidirectional_half_closure_terminates_cleanly() { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()), + )); + + // Half closure: drop the client completely but leave the server active. + drop(client_peer); + + // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. + // Eventually dropping the server cleanly closes the task. 
+ drop(server_peer); + timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn relay_bidirectional_zero_length_noise_fuzzing() { + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()), + )); + + // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) + for _ in 0..100 { + client_peer.write_all(&[]).await.unwrap(); + } + client_peer.write_all(&[1, 2, 3]).await.unwrap(); + client_peer.flush().await.unwrap(); + + let mut buf = [0u8; 3]; + server_peer.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, &[1, 2, 3]); + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn relay_bidirectional_asymmetric_backpressure() { + let stats = Arc::new(Stats::new()); + // Give the client stream an extremely narrow throughput limit explicitly + let (client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_task = tokio::spawn(relay_bidirectional( + client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()), + )); + + let payload = vec![0xba; 65536]; // 64k payload + + // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! 
+ let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; + + assert!( + write_res.is_err(), + "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" + ); + + drop(client_peer); + drop(server_peer); + + let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap(); + assert!( + completion.is_ok() || completion.is_err(), + "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" + ); +} + +use rand::{Rng, SeedableRng, rngs::StdRng}; + +#[tokio::test] +async fn relay_bidirectional_light_fuzzing_temporal_jitter() { + tokio::time::pause(); + let stats = Arc::new(Stats::new()); + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let mut relay_task = tokio::spawn(relay_bidirectional( + client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()), + )); + + let mut rng = StdRng::seed_from_u64(0xDEADBEEF); + + for _ in 0..10 { + // Vary timing significantly up to 1600 seconds (limit is 1800s) + let jitter = rng.random_range(100..1600); + tokio::time::sleep(Duration::from_secs(jitter)).await; + + client_peer.write_all(&[0x11]).await.unwrap(); + client_peer.flush().await.unwrap(); + + // Ensure task has not died + let res = timeout(Duration::from_millis(10), &mut relay_task).await; + assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses"); + } + + drop(client_peer); + drop(server_peer); + timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap(); +} + +struct FaultyReader { + error_once: Option<io::Error>, +} + +struct TwoPartyGate { + arrivals: AtomicUsize, + total_bytes: 
AtomicUsize, + wakers: Mutex<Vec<Waker>>, +} + +impl TwoPartyGate { + fn new() -> Self { + Self { + arrivals: AtomicUsize::new(0), + total_bytes: AtomicUsize::new(0), + wakers: Mutex::new(Vec::new()), + } + } + + fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { + if self.arrivals.load(Ordering::Relaxed) >= 2 { + return true; + } + + let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); + if prev + 1 >= 2 { + let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); + for waker in wakers.drain(..) { + waker.wake(); + } + true + } else { + let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); + wakers.push(cx.waker().clone()); + false + } + } + + fn total_bytes(&self) -> usize { + self.total_bytes.load(Ordering::Relaxed) + } +} + +struct GateWriter { + gate: Arc<TwoPartyGate>, + entered: bool, +} + +impl GateWriter { + fn new(gate: Arc<TwoPartyGate>) -> Self { + Self { + gate, + entered: false, + } + } +} + +impl AsyncWrite for GateWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if !self.entered { + self.entered = true; + } + + if !self.gate.arrive_or_park(cx) { + return Poll::Pending; + } + + self.gate + .total_bytes + .fetch_add(buf.len(), Ordering::Relaxed); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } +} + +struct GateReader { + gate: Arc<TwoPartyGate>, + entered: bool, + emitted: bool, +} + +impl GateReader { + fn new(gate: Arc<TwoPartyGate>) -> Self { + Self { + gate, + entered: false, + emitted: false, + } + } +} + +impl AsyncRead for GateReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if self.emitted { + return Poll::Ready(Ok(())); + } + + if !self.entered { + self.entered = true; + } + + if !self.gate.arrive_or_park(cx) { + return Poll::Pending; + } + + 
buf.put_slice(&[0x42]); + self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); + self.emitted = true; + Poll::Ready(Ok(())) + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { + let stats = Arc::new(Stats::new()); + let gate = Arc::new(TwoPartyGate::new()); + let user = "concurrent-quota-write".to_string(); + + let writer_a = super::StatsIo::new( + GateWriter::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let writer_b = super::StatsIo::new( + GateWriter::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let task_a = tokio::spawn(async move { + let mut w = writer_a; + AsyncWriteExt::write_all(&mut w, &[0x01]).await + }); + let task_b = tokio::spawn(async move { + let mut w = writer_b; + AsyncWriteExt::write_all(&mut w, &[0x02]).await + }); + + let (res_a, res_b) = tokio::join!(task_a, task_b); + let _ = res_a.expect("task a must join"); + let _ = res_b.expect("task b must join"); + + assert!( + gate.total_bytes() <= 1, + "concurrent same-user writes must not forward more than one byte under quota=1" + ); + assert!( + stats.get_user_total_octets(&user) <= 1, + "concurrent same-user writes must not account over limit" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { + let stats = Arc::new(Stats::new()); + let gate = Arc::new(TwoPartyGate::new()); + let user = "concurrent-quota-read".to_string(); + + let reader_a = super::StatsIo::new( + GateReader::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + 
Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let reader_b = super::StatsIo::new( + GateReader::new(Arc::clone(&gate)), + Arc::new(super::SharedCounters::new()), + Arc::clone(&stats), + user.clone(), + Some(1), + Arc::new(std::sync::atomic::AtomicBool::new(false)), + tokio::time::Instant::now(), + ); + + let task_a = tokio::spawn(async move { + let mut r = reader_a; + let mut one = [0u8; 1]; + AsyncReadExt::read_exact(&mut r, &mut one).await + }); + let task_b = tokio::spawn(async move { + let mut r = reader_b; + let mut one = [0u8; 1]; + AsyncReadExt::read_exact(&mut r, &mut one).await + }); + + let (res_a, res_b) = tokio::join!(task_a, task_b); + let _ = res_a.expect("task a must join"); + let _ = res_b.expect("task b must join"); + + assert!( + gate.total_bytes() <= 1, + "concurrent same-user reads must not consume more than one byte under quota=1" + ); + assert!( + stats.get_user_total_octets(&user) <= 1, + "concurrent same-user reads must not account over limit" + ); +} + +#[tokio::test] +async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { + let stats = Arc::new(Stats::new()); + let user = "parallel-quota-user"; + + for _ in 0..128 { + let (mut client_peer_a, relay_client_a) = duplex(256); + let (relay_server_a, mut server_peer_a) = duplex(256); + let (mut client_peer_b, relay_client_b) = duplex(256); + let (relay_server_b, mut server_peer_b) = duplex(256); + + let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); + let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); + let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); + let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); + + let relay_a = tokio::spawn(relay_bidirectional( + client_reader_a, + client_writer_a, + server_reader_a, + server_writer_a, + 64, + 64, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let 
relay_b = tokio::spawn(relay_bidirectional( + client_reader_b, + client_writer_b, + server_reader_b, + server_writer_b, + 64, + 64, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer_a.write_all(&[0x01]), + server_peer_a.write_all(&[0x02]), + client_peer_b.write_all(&[0x03]), + server_peer_b.write_all(&[0x04]), + ); + + let _ = timeout(Duration::from_millis(50), poll_fn(|cx| { + let mut one = [0u8; 1]; + let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); + Poll::Ready(()) + })) + .await; + + drop(client_peer_a); + drop(server_peer_a); + drop(client_peer_b); + drop(server_peer_b); + + let _ = timeout(Duration::from_secs(1), relay_a).await; + let _ = timeout(Duration::from_secs(1), relay_b).await; + + assert!( + stats.get_user_total_octets(user) <= 1, + "parallel relays must not exceed configured quota" + ); + } +} + +impl FaultyReader { + fn permission_denied_with_message(message: impl Into<String>) -> Self { + Self { + error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())), + } + } +} + +impl AsyncRead for FaultyReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if let Some(err) = self.error_once.take() { + return Poll::Ready(Err(err)); + } + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + FaultyReader::permission_denied_with_message("user data quota exceeded"), + tokio::io::sink(), + 1024, + 1024, + "non-quota-permission-denied", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + ) + .await; + + drop(client_peer); + + assert!( + 
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), + "non-quota transport PermissionDenied errors must remain IO errors" + ); +} + +#[tokio::test] +async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { + let mut rng = StdRng::seed_from_u64(0xA11CE0B5); + + for i in 0..128u64 { + let stats = Arc::new(Stats::new()); + let (client_peer, relay_client) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + + let random_len = rng.random_range(1..=48); + let mut msg = String::with_capacity(random_len); + for _ in 0..random_len { + let ch = (b'a' + (rng.random::<u8>() % 26)) as char; + msg.push(ch); + } + // Include the legacy quota string in a subset of fuzz cases to validate + // collision resistance against message-based classification. + if i % 7 == 0 { + msg = "user data quota exceeded".to_string(); + } + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + FaultyReader::permission_denied_with_message(msg), + tokio::io::sink(), + 1024, + 1024, + "fuzz-perm-denied", + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + ) + .await; + + drop(client_peer); + + assert!( + matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), + "transport PermissionDenied case must stay typed as IO regardless of message content" + ); + } +} + +#[tokio::test] +async fn relay_half_close_keeps_reverse_direction_progressing() { + let stats = Arc::new(Stats::new()); + let user = "half-close-user"; + + let (client_peer, relay_client) = duplex(1024); + let (relay_server, server_peer) = duplex(1024); + + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); + let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); + + let relay_task = 
tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 8192, + 8192, + user, + Arc::clone(&stats), + None, + Arc::new(BufferPool::new()), + )); + + sp_writer.write_all(&[0x10, 0x20, 0x30, 0x40]).await.unwrap(); + sp_writer.shutdown().await.unwrap(); + + let mut inbound = [0u8; 4]; + cp_reader.read_exact(&mut inbound).await.unwrap(); + assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); + + cp_writer.write_all(&[0xaa, 0xbb, 0xcc, 0xdd]).await.unwrap(); + let mut outbound = [0u8; 4]; + sp_reader.read_exact(&mut outbound).await.unwrap(); + assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); + + relay_task.abort(); + let joined = relay_task.await; + assert!(joined.is_err(), "aborted relay task must return join error"); +} diff --git a/src/proxy/tests/relay_watchdog_delta_security_tests.rs b/src/proxy/tests/relay_watchdog_delta_security_tests.rs new file mode 100644 index 0000000..f05ee62 --- /dev/null +++ b/src/proxy/tests/relay_watchdog_delta_security_tests.rs @@ -0,0 +1,61 @@ +use super::watchdog_delta; + +#[test] +fn positive_monotonic_growth_returns_exact_delta() { + assert_eq!(watchdog_delta(42, 40), 2); + assert_eq!(watchdog_delta(4096, 1024), 3072); +} + +#[test] +fn edge_equal_values_return_zero_delta() { + assert_eq!(watchdog_delta(0, 0), 0); + assert_eq!(watchdog_delta(777, 777), 0); +} + +#[test] +fn adversarial_wrap_like_regression_saturates_to_zero() { + // Simulates a wrapped or reset counter observation where current < previous. + assert_eq!(watchdog_delta(0, 1), 0); + assert_eq!(watchdog_delta(12, 4096), 0); +} + +#[test] +fn adversarial_blackhat_large_previous_value_never_underflows() { + let current = 3u64; + let previous = u64::MAX - 1; + assert_eq!(watchdog_delta(current, previous), 0); +} + +#[test] +fn light_fuzz_mixed_pairs_match_saturating_sub_contract() { + // Deterministic xorshift64* generator for reproducible pseudo-fuzzing. 
+ let mut seed = 0xA51C_ED42_D00D_F00Du64; + + for _ in 0..10_000 { + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let current = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + seed ^= seed >> 12; + seed ^= seed << 25; + seed ^= seed >> 27; + let previous = seed.wrapping_mul(0x2545_F491_4F6C_DD1D); + + let expected = current.saturating_sub(previous); + let actual = watchdog_delta(current, previous); + assert_eq!(actual, expected, "delta mismatch for ({current}, {previous})"); + } +} + +#[test] +fn stress_long_running_monotonic_sequence_remains_exact() { + let mut prev = 0u64; + + for step in 1u64..=200_000 { + let curr = prev.saturating_add(step & 0x7); + let delta = watchdog_delta(curr, prev); + assert_eq!(delta, curr - prev); + prev = curr; + } +} diff --git a/src/proxy/tests/route_mode_coherence_adversarial_tests.rs b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs new file mode 100644 index 0000000..4f255d4 --- /dev/null +++ b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs @@ -0,0 +1,228 @@ +use super::*; +use rand::rngs::StdRng; +use rand::{RngExt, SeedableRng}; +use std::sync::Arc; + +#[test] +fn positive_direct_cutover_sets_timestamp_and_snapshot_coherently() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle startup must not expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!(observed.mode, RelayRouteMode::Direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct cutover must publish a non-empty direct-since timestamp" + ); +} + +#[test] +fn negative_idempotent_set_mode_does_not_mutate_timestamp_or_generation() { + let runtime = 
RouteRuntimeController::new(RelayRouteMode::Direct); + + let before_state = runtime.snapshot(); + let before_ts = runtime.direct_since_epoch_secs(); + + let changed = runtime.set_mode(RelayRouteMode::Direct); + + let after_state = runtime.snapshot(); + let after_ts = runtime.direct_since_epoch_secs(); + + assert!(changed.is_none(), "idempotent set_mode must return None"); + assert_eq!( + after_state.generation, before_state.generation, + "idempotent set_mode must not advance generation" + ); + assert_eq!( + after_ts, before_ts, + "idempotent set_mode must not alter direct-since timestamp" + ); +} + +#[test] +fn edge_middle_cutover_clears_timestamp() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct startup must expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!(observed.mode, RelayRouteMode::Middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle cutover must clear direct-since timestamp" + ); +} + +#[test] +fn adversarial_blackhat_probe_sequence_observes_consistent_mode_timestamp_pairs() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + for _ in 0..2048usize { + let emitted_direct = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit"); + let observed_direct = *rx.borrow(); + assert_eq!(observed_direct, emitted_direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct observation must never expose empty timestamp" + ); + + let emitted_middle = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit"); + let observed_middle = *rx.borrow(); + assert_eq!(observed_middle, 
emitted_middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle observation must never expose direct timestamp" + ); + } +} + +#[test] +fn integration_subscriber_and_runtime_gates_stay_coherent_across_cutovers() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + let plan = [ + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + ]; + + let mut expected_generation = 0u64; + + for mode in plan { + let emitted = runtime + .set_mode(mode) + .expect("each planned transition toggles mode and must emit"); + expected_generation = expected_generation.saturating_add(1); + + let watched = *rx.borrow(); + let snapshot = runtime.snapshot(); + + assert_eq!(emitted.mode, mode); + assert_eq!(emitted.generation, expected_generation); + assert_eq!(watched, emitted); + assert_eq!(snapshot, emitted); + + if matches!(mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } + } +} + +#[test] +fn light_fuzz_random_mode_plan_preserves_timestamp_and_generation_invariants() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let mut rng = StdRng::seed_from_u64(0x5EED_CAFE_D15C_A11E); + + let mut expected_mode = RelayRouteMode::Middle; + let mut expected_generation = 0u64; + + for _ in 0..25_000usize { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + + let changed = runtime.set_mode(candidate); + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent fuzz step must not emit"); + continue; + } + + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + + let emitted = changed.expect("non-idempotent fuzz step must emit"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, 
expected_generation); + + let snapshot = runtime.snapshot(); + assert_eq!(snapshot, emitted, "snapshot must match emitted cutover"); + + if matches!(snapshot.mode, RelayRouteMode::Direct) { + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct fuzz state must expose timestamp" + ); + } else { + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle fuzz state must clear timestamp" + ); + } + } +} + +#[test] +fn stress_parallel_subscribers_never_observe_generation_regression() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let mut readers = Vec::new(); + for _ in 0..4usize { + let runtime = Arc::clone(&runtime); + readers.push(std::thread::spawn(move || { + let rx = runtime.subscribe(); + let mut last = rx.borrow().generation; + for _ in 0..10_000usize { + let current = rx.borrow().generation; + assert!( + current >= last, + "watch generation must be monotonic for every subscriber" + ); + last = current; + std::thread::yield_now(); + } + })); + } + + for step in 0..20_000usize { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + + for reader in readers { + reader + .join() + .expect("parallel subscriber reader must not panic"); + } + + let final_state = runtime.snapshot(); + if matches!(final_state.mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } +} diff --git a/src/proxy/tests/route_mode_security_tests.rs b/src/proxy/tests/route_mode_security_tests.rs new file mode 100644 index 0000000..49cbb66 --- /dev/null +++ b/src/proxy/tests/route_mode_security_tests.rs @@ -0,0 +1,406 @@ +use super::*; +use rand::{RngExt, SeedableRng}; +use rand::rngs::StdRng; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +#[test] +fn cutover_stagger_delay_is_deterministic_for_same_inputs() { + let d1 = 
cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + let d2 = cutover_stagger_delay(0x0123_4567_89ab_cdef, 42); + assert_eq!( + d1, d2, + "stagger delay must be deterministic for identical session/generation inputs" + ); +} + +#[test] +fn cutover_stagger_delay_stays_within_budget_bounds() { + // Black-hat model: censors trigger many cutovers and correlate disconnect timing. + // Keep delay inside a narrow coarse window to avoid long-tail spikes. + for generation in [0u64, 1, 2, 3, 16, 128, u32::MAX as u64, u64::MAX] { + for session_id in [ + 0u64, + 1, + 2, + 0xdead_beef, + 0xfeed_face_cafe_babe, + u64::MAX, + ] { + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "stagger delay must remain in fixed 1000..=1999ms budget" + ); + } + } +} + +#[test] +fn cutover_stagger_delay_changes_with_generation_for_same_session() { + let session_id = 0x0123_4567_89ab_cdef; + let first = cutover_stagger_delay(session_id, 100); + let second = cutover_stagger_delay(session_id, 101); + assert_ne!( + first, second, + "adjacent cutover generations should decorrelate disconnect delays" + ); +} + +#[test] +fn route_runtime_set_mode_is_idempotent_for_same_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let first = runtime.snapshot(); + let changed = runtime.set_mode(RelayRouteMode::Direct); + let second = runtime.snapshot(); + + assert!( + changed.is_none(), + "setting already-active mode must not produce a cutover event" + ); + assert_eq!( + first.generation, second.generation, + "idempotent mode set must not bump generation" + ); +} + +#[test] +fn affected_cutover_state_triggers_only_for_newer_generation() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let initial = runtime.snapshot(); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation).is_none(), + "current generation must not be considered a 
cutover for existing session" + ); + + let next = runtime + .set_mode(RelayRouteMode::Middle) + .expect("mode change must produce cutover state"); + let seen = affected_cutover_state(&rx, RelayRouteMode::Direct, initial.generation) + .expect("newer generation must be observed as cutover"); + + assert_eq!(seen.generation, next.generation); + assert_eq!(seen.mode, RelayRouteMode::Middle); +} + +#[test] +fn integration_watch_and_snapshot_follow_same_transition_sequence() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + let sequence = [ + RelayRouteMode::Middle, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + ]; + + let mut expected_generation = 0u64; + let mut expected_mode = RelayRouteMode::Direct; + + for target in sequence { + let changed = runtime.set_mode(target); + if target == expected_mode { + assert!(changed.is_none(), "idempotent transition must return none"); + } else { + expected_mode = target; + expected_generation = expected_generation.saturating_add(1); + let emitted = changed.expect("real transition must emit cutover state"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, expected_generation); + } + + let snap = runtime.snapshot(); + let watched = *rx.borrow(); + assert_eq!(snap, watched, "snapshot and watch state must stay aligned"); + assert_eq!(snap.mode, expected_mode); + assert_eq!(snap.generation, expected_generation); + } +} + +#[test] +fn session_is_not_affected_when_mode_matches_even_if_generation_advanced() { + let session_mode = RelayRouteMode::Direct; + let current = RouteCutoverState { + mode: RelayRouteMode::Direct, + generation: 2, + }; + let session_generation = 0; + + assert!( + !is_session_affected_by_cutover(current, session_mode, session_generation), + "session on matching final route mode should not be force-cut over on intermediate generation bumps" + ); +} + +#[test] +fn 
cutover_predicate_rejects_equal_generation_even_if_mode_differs() { + let current = RouteCutoverState { + mode: RelayRouteMode::Middle, + generation: 77, + }; + assert!( + !is_session_affected_by_cutover(current, RelayRouteMode::Direct, 77), + "equal generation must never trigger cutover regardless of mode mismatch" + ); +} + +#[test] +fn adversarial_route_oscillation_only_cuts_over_sessions_with_different_final_mode() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + let session_generation = runtime.snapshot().generation; + + runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must transition"); + runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must transition"); + + assert!( + affected_cutover_state(&rx, RelayRouteMode::Direct, session_generation).is_none(), + "direct session should survive when final mode returns to direct" + ); + assert!( + affected_cutover_state(&rx, RelayRouteMode::Middle, session_generation).is_some(), + "middle session should be cut over when final mode is direct" + ); +} + +#[test] +fn light_fuzz_cutover_predicate_matches_reference_oracle() { + let mut rng = StdRng::seed_from_u64(0xC0DEC0DE5EED); + for _ in 0..20_000 { + let current = RouteCutoverState { + mode: if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }, + generation: rng.random_range(0u64..1_000_000), + }; + let session_mode = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let session_generation = rng.random_range(0u64..1_000_000); + + let expected = current.generation > session_generation && current.mode != session_mode; + let actual = is_session_affected_by_cutover(current, session_mode, session_generation); + assert_eq!( + actual, expected, + "cutover predicate must match mode-aware generation oracle" + ); + } +} + +#[test] +fn light_fuzz_set_mode_generation_tracks_only_real_transitions() { + let runtime = 
RouteRuntimeController::new(RelayRouteMode::Direct); + let mut rng = StdRng::seed_from_u64(0x0DDC0FFE); + + let mut expected_mode = RelayRouteMode::Direct; + let mut expected_generation = 0u64; + + for _ in 0..10_000 { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let changed = runtime.set_mode(candidate); + + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent set_mode must not emit cutover state"); + } else { + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + let next = changed.expect("mode transition must emit cutover state"); + assert_eq!(next.mode, expected_mode); + assert_eq!(next.generation, expected_generation); + } + } + + let final_state = runtime.snapshot(); + assert_eq!(final_state.mode, expected_mode); + assert_eq!(final_state.generation, expected_generation); +} + +#[test] +fn stress_snapshot_and_watch_state_remain_consistent_under_concurrent_switch_storm() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + + std::thread::scope(|scope| { + let mut writers = Vec::new(); + for worker in 0..4usize { + let runtime = Arc::clone(&runtime); + writers.push(scope.spawn(move || { + for step in 0..20_000usize { + let mode = if (worker + step) % 2 == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + })); + } + + for writer in writers { + writer + .join() + .expect("route mode writer thread must not panic"); + } + + let rx = runtime.subscribe(); + for _ in 0..128 { + assert_eq!( + runtime.snapshot(), + *rx.borrow(), + "snapshot and watch state must converge after concurrent set_mode churn" + ); + std::thread::yield_now(); + } + }); +} + +#[test] +fn stress_concurrent_transition_count_matches_final_generation() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let successful_transitions = Arc::new(AtomicU64::new(0)); + + 
std::thread::scope(|scope| { + let mut workers = Vec::new(); + for worker in 0..6usize { + let runtime = Arc::clone(&runtime); + let successful_transitions = Arc::clone(&successful_transitions); + workers.push(scope.spawn(move || { + let mut state = (worker as u64 + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15); + for _ in 0..25_000usize { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + let mode = if (state & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + if runtime.set_mode(mode).is_some() { + successful_transitions.fetch_add(1, Ordering::Relaxed); + } + } + })); + } + + for worker in workers { + worker.join().expect("route mode transition worker must not panic"); + } + }); + + let final_state = runtime.snapshot(); + assert_eq!( + final_state.generation, + successful_transitions.load(Ordering::Relaxed), + "final generation must equal number of accepted mode transitions" + ); + assert_eq!( + final_state, + *runtime.subscribe().borrow(), + "watch and snapshot state must match after concurrent transition accounting" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_in_fixed_window() { + // Deterministic xorshift fuzzing keeps this test stable across runs. 
+ let mut s: u64 = 0x9E37_79B9_7F4A_7C15; + + for _ in 0..20_000 { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay = cutover_stagger_delay(session_id, generation); + assert!( + (1000..=1999).contains(&delay.as_millis()), + "fuzzed inputs must always map into fixed stagger window" + ); + } +} + +#[test] +fn cutover_stagger_delay_distribution_has_no_empty_buckets_under_sequential_sessions() { + let mut buckets = [0usize; 1000]; + let generation = 4242u64; + + for session_id in 0..250_000u64 { + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + let idx = delay_ms - 1000; + buckets[idx] += 1; + } + + let empty = buckets.iter().filter(|&&count| count == 0).count(); + assert_eq!( + empty, 0, + "all 1000 delay buckets must be exercised to avoid cutover herd clustering" + ); +} + +#[test] +fn light_fuzz_cutover_stagger_delay_distribution_stays_reasonably_uniform() { + let mut buckets = [0usize; 1000]; + let mut s: u64 = 0x1BAD_B002_CAFE_F00D; + + for _ in 0..300_000usize { + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let session_id = s; + + s ^= s << 7; + s ^= s >> 9; + s ^= s << 8; + let generation = s; + + let delay_ms = cutover_stagger_delay(session_id, generation).as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!(min > 0, "fuzzed distribution must not leave empty buckets"); + assert!( + max <= min.saturating_mul(3), + "bucket skew is too high for anti-herd staggering (max={max}, min={min})" + ); +} + +#[test] +fn stress_cutover_stagger_delay_distribution_remains_stable_across_generations() { + for generation in [0u64, 1, 7, 31, 255, 1024, u32::MAX as u64, u64::MAX - 1] { + let mut buckets = [0usize; 1000]; + for session_id in 0..100_000u64 { + let delay_ms = cutover_stagger_delay(session_id ^ 0x9E37_79B9, 
generation) + .as_millis() as usize; + buckets[delay_ms - 1000] += 1; + } + + let min = *buckets.iter().min().unwrap_or(&0); + let max = *buckets.iter().max().unwrap_or(&0); + assert!( + max <= min.saturating_mul(4).max(1), + "generation={generation}: distribution collapsed (max={max}, min={min})" + ); + } +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 0df4dc0..c9fc318 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -6,6 +6,7 @@ pub mod beobachten; pub mod telemetry; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use parking_lot::Mutex; @@ -19,135 +20,44 @@ use tracing::debug; use crate::config::{MeTelemetryLevel, MeWriterPickMode}; use self::telemetry::TelemetryPolicy; -const ME_WRITER_TEARDOWN_MODE_COUNT: usize = 2; -const ME_WRITER_TEARDOWN_REASON_COUNT: usize = 11; -const ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT: usize = 2; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT: usize = 12; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_BOUNDS_MICROS: [u64; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] = [ - 1_000, - 5_000, - 10_000, - 25_000, - 50_000, - 100_000, - 250_000, - 500_000, - 1_000_000, - 2_500_000, - 5_000_000, - 10_000_000, -]; -const ME_WRITER_TEARDOWN_DURATION_BUCKET_LABELS: [&str; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] = [ - "0.001", - "0.005", - "0.01", - "0.025", - "0.05", - "0.1", - "0.25", - "0.5", - "1", - "2.5", - "5", - "10", -]; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum MeWriterTeardownMode { - Normal = 0, - HardDetach = 1, +#[derive(Clone, Copy)] +enum RouteConnectionGauge { + Direct, + Middle, } -impl MeWriterTeardownMode { - pub const ALL: [Self; ME_WRITER_TEARDOWN_MODE_COUNT] = - [Self::Normal, Self::HardDetach]; +#[must_use = "RouteConnectionLease must be kept alive to hold the connection gauge increment"] +pub struct RouteConnectionLease { + stats: Arc, + gauge: 
RouteConnectionGauge, + active: bool, +} - pub const fn as_str(self) -> &'static str { - match self { - Self::Normal => "normal", - Self::HardDetach => "hard_detach", +impl RouteConnectionLease { + fn new(stats: Arc, gauge: RouteConnectionGauge) -> Self { + Self { + stats, + gauge, + active: true, } } - const fn idx(self) -> usize { - self as usize + #[cfg(test)] + fn disarm(&mut self) { + self.active = false; } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum MeWriterTeardownReason { - ReaderExit = 0, - WriterTaskExit = 1, - PingSendFail = 2, - SignalSendFail = 3, - RouteChannelClosed = 4, - CloseRpcChannelClosed = 5, - PruneClosedWriter = 6, - ReapTimeoutExpired = 7, - ReapThresholdForce = 8, - ReapEmpty = 9, - WatchdogStuckDraining = 10, -} - -impl MeWriterTeardownReason { - pub const ALL: [Self; ME_WRITER_TEARDOWN_REASON_COUNT] = [ - Self::ReaderExit, - Self::WriterTaskExit, - Self::PingSendFail, - Self::SignalSendFail, - Self::RouteChannelClosed, - Self::CloseRpcChannelClosed, - Self::PruneClosedWriter, - Self::ReapTimeoutExpired, - Self::ReapThresholdForce, - Self::ReapEmpty, - Self::WatchdogStuckDraining, - ]; - - pub const fn as_str(self) -> &'static str { - match self { - Self::ReaderExit => "reader_exit", - Self::WriterTaskExit => "writer_task_exit", - Self::PingSendFail => "ping_send_fail", - Self::SignalSendFail => "signal_send_fail", - Self::RouteChannelClosed => "route_channel_closed", - Self::CloseRpcChannelClosed => "close_rpc_channel_closed", - Self::PruneClosedWriter => "prune_closed_writer", - Self::ReapTimeoutExpired => "reap_timeout_expired", - Self::ReapThresholdForce => "reap_threshold_force", - Self::ReapEmpty => "reap_empty", - Self::WatchdogStuckDraining => "watchdog_stuck_draining", +impl Drop for RouteConnectionLease { + fn drop(&mut self) { + if !self.active { + return; } - } - - const fn idx(self) -> usize { - self as usize - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -pub enum 
MeWriterCleanupSideEffectStep { - CloseSignalChannelFull = 0, - CloseSignalChannelClosed = 1, -} - -impl MeWriterCleanupSideEffectStep { - pub const ALL: [Self; ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT] = - [Self::CloseSignalChannelFull, Self::CloseSignalChannelClosed]; - - pub const fn as_str(self) -> &'static str { - match self { - Self::CloseSignalChannelFull => "close_signal_channel_full", - Self::CloseSignalChannelClosed => "close_signal_channel_closed", + match self.gauge { + RouteConnectionGauge::Direct => self.stats.decrement_current_connections_direct(), + RouteConnectionGauge::Middle => self.stats.decrement_current_connections_me(), } } - - const fn idx(self) -> usize { - self as usize - } } // ============= Stats ============= @@ -189,6 +99,10 @@ pub struct Stats { me_handshake_reject_total: AtomicU64, me_reader_eof_total: AtomicU64, me_idle_close_by_peer_total: AtomicU64, + relay_idle_soft_mark_total: AtomicU64, + relay_idle_hard_close_total: AtomicU64, + relay_pressure_evict_total: AtomicU64, + relay_protocol_desync_close_total: AtomicU64, me_crc_mismatch: AtomicU64, me_seq_mismatch: AtomicU64, me_endpoint_quarantine_total: AtomicU64, @@ -251,26 +165,9 @@ pub struct Stats { pool_swap_total: AtomicU64, pool_drain_active: AtomicU64, pool_force_close_total: AtomicU64, - pool_drain_soft_evict_total: AtomicU64, - pool_drain_soft_evict_writer_total: AtomicU64, pool_stale_pick_total: AtomicU64, - me_writer_close_signal_drop_total: AtomicU64, - me_writer_close_signal_channel_full_total: AtomicU64, - me_draining_writers_reap_progress_total: AtomicU64, me_writer_removed_total: AtomicU64, me_writer_removed_unexpected_total: AtomicU64, - me_writer_teardown_attempt_total: - [[AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT]; ME_WRITER_TEARDOWN_REASON_COUNT], - me_writer_teardown_success_total: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_timeout_total: AtomicU64, - me_writer_teardown_escalation_total: AtomicU64, - me_writer_teardown_noop_total: 
AtomicU64, - me_writer_cleanup_side_effect_failures_total: - [AtomicU64; ME_WRITER_CLEANUP_SIDE_EFFECT_STEP_COUNT], - me_writer_teardown_duration_bucket_hits: - [[AtomicU64; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT + 1]; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_duration_sum_micros: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], - me_writer_teardown_duration_count: [AtomicU64; ME_WRITER_TEARDOWN_MODE_COUNT], me_refill_triggered_total: AtomicU64, me_refill_skipped_inflight_total: AtomicU64, me_refill_failed_total: AtomicU64, @@ -281,11 +178,6 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, - relay_adaptive_promotions_total: AtomicU64, - relay_adaptive_demotions_total: AtomicU64, - relay_adaptive_hard_promotions_total: AtomicU64, - reconnect_evict_total: AtomicU64, - reconnect_stale_close_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, @@ -438,35 +330,15 @@ impl Stats { pub fn decrement_current_connections_me(&self) { Self::decrement_atomic_saturating(&self.current_connections_me); } - pub fn increment_relay_adaptive_promotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_promotions_total - .fetch_add(1, Ordering::Relaxed); - } + + pub fn acquire_direct_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_direct(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Direct) } - pub fn increment_relay_adaptive_demotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_demotions_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_relay_adaptive_hard_promotions_total(&self) { - if self.telemetry_core_enabled() { - self.relay_adaptive_hard_promotions_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_reconnect_evict_total(&self) { - if 
self.telemetry_core_enabled() { - self.reconnect_evict_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_reconnect_stale_close_total(&self) { - if self.telemetry_core_enabled() { - self.reconnect_stale_close_total - .fetch_add(1, Ordering::Relaxed); - } + + pub fn acquire_me_connection_lease(self: &Arc) -> RouteConnectionLease { + self.increment_current_connections_me(); + RouteConnectionLease::new(self.clone(), RouteConnectionGauge::Middle) } pub fn increment_handshake_timeouts(&self) { if self.telemetry_core_enabled() { @@ -657,6 +529,30 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_relay_idle_soft_mark_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_soft_mark_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_idle_hard_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_idle_hard_close_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_pressure_evict_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_pressure_evict_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_relay_protocol_desync_close_total(&self) { + if self.telemetry_me_allows_normal() { + self.relay_protocol_desync_close_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_crc_mismatch(&self) { if self.telemetry_me_allows_normal() { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); @@ -863,41 +759,11 @@ impl Stats { self.pool_force_close_total.fetch_add(1, Ordering::Relaxed); } } - pub fn increment_pool_drain_soft_evict_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_drain_soft_evict_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_pool_drain_soft_evict_writer_total(&self) { - if self.telemetry_me_allows_normal() { - self.pool_drain_soft_evict_writer_total - .fetch_add(1, Ordering::Relaxed); - } - } pub fn increment_pool_stale_pick_total(&self) { if 
self.telemetry_me_allows_normal() { self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed); } } - pub fn increment_me_writer_close_signal_drop_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_close_signal_drop_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_close_signal_channel_full_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_close_signal_channel_full_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_draining_writers_reap_progress_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_draining_writers_reap_progress_total - .fetch_add(1, Ordering::Relaxed); - } - } pub fn increment_me_writer_removed_total(&self) { if self.telemetry_me_allows_debug() { self.me_writer_removed_total.fetch_add(1, Ordering::Relaxed); @@ -908,74 +774,6 @@ impl Stats { self.me_writer_removed_unexpected_total.fetch_add(1, Ordering::Relaxed); } } - pub fn increment_me_writer_teardown_attempt_total( - &self, - reason: MeWriterTeardownReason, - mode: MeWriterTeardownMode, - ) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_attempt_total[reason.idx()][mode.idx()] - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_success_total(&self, mode: MeWriterTeardownMode) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_success_total[mode.idx()].fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_timeout_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_timeout_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_escalation_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_escalation_total - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn increment_me_writer_teardown_noop_total(&self) { - if self.telemetry_me_allows_normal() { - self.me_writer_teardown_noop_total - .fetch_add(1, 
Ordering::Relaxed); - } - } - pub fn increment_me_writer_cleanup_side_effect_failures_total( - &self, - step: MeWriterCleanupSideEffectStep, - ) { - if self.telemetry_me_allows_normal() { - self.me_writer_cleanup_side_effect_failures_total[step.idx()] - .fetch_add(1, Ordering::Relaxed); - } - } - pub fn observe_me_writer_teardown_duration( - &self, - mode: MeWriterTeardownMode, - duration: Duration, - ) { - if !self.telemetry_me_allows_normal() { - return; - } - let duration_micros = duration.as_micros().min(u64::MAX as u128) as u64; - let mut bucket_idx = ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT; - for (idx, upper_bound_micros) in ME_WRITER_TEARDOWN_DURATION_BUCKET_BOUNDS_MICROS - .iter() - .copied() - .enumerate() - { - if duration_micros <= upper_bound_micros { - bucket_idx = idx; - break; - } - } - self.me_writer_teardown_duration_bucket_hits[mode.idx()][bucket_idx] - .fetch_add(1, Ordering::Relaxed); - self.me_writer_teardown_duration_sum_micros[mode.idx()] - .fetch_add(duration_micros, Ordering::Relaxed); - self.me_writer_teardown_duration_count[mode.idx()].fetch_add(1, Ordering::Relaxed); - } pub fn increment_me_refill_triggered_total(&self) { if self.telemetry_me_allows_debug() { self.me_refill_triggered_total.fetch_add(1, Ordering::Relaxed); @@ -1214,22 +1012,6 @@ impl Stats { self.get_current_connections_direct() .saturating_add(self.get_current_connections_me()) } - pub fn get_relay_adaptive_promotions_total(&self) -> u64 { - self.relay_adaptive_promotions_total.load(Ordering::Relaxed) - } - pub fn get_relay_adaptive_demotions_total(&self) -> u64 { - self.relay_adaptive_demotions_total.load(Ordering::Relaxed) - } - pub fn get_relay_adaptive_hard_promotions_total(&self) -> u64 { - self.relay_adaptive_hard_promotions_total - .load(Ordering::Relaxed) - } - pub fn get_reconnect_evict_total(&self) -> u64 { - self.reconnect_evict_total.load(Ordering::Relaxed) - } - pub fn get_reconnect_stale_close_total(&self) -> u64 { - 
self.reconnect_stale_close_total.load(Ordering::Relaxed) - } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) } @@ -1265,6 +1047,18 @@ impl Stats { pub fn get_me_idle_close_by_peer_total(&self) -> u64 { self.me_idle_close_by_peer_total.load(Ordering::Relaxed) } + pub fn get_relay_idle_soft_mark_total(&self) -> u64 { + self.relay_idle_soft_mark_total.load(Ordering::Relaxed) + } + pub fn get_relay_idle_hard_close_total(&self) -> u64 { + self.relay_idle_hard_close_total.load(Ordering::Relaxed) + } + pub fn get_relay_pressure_evict_total(&self) -> u64 { + self.relay_pressure_evict_total.load(Ordering::Relaxed) + } + pub fn get_relay_protocol_desync_close_total(&self) -> u64 { + self.relay_protocol_desync_close_total.load(Ordering::Relaxed) + } pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } pub fn get_me_endpoint_quarantine_total(&self) -> u64 { @@ -1482,105 +1276,15 @@ impl Stats { pub fn get_pool_force_close_total(&self) -> u64 { self.pool_force_close_total.load(Ordering::Relaxed) } - pub fn get_pool_drain_soft_evict_total(&self) -> u64 { - self.pool_drain_soft_evict_total.load(Ordering::Relaxed) - } - pub fn get_pool_drain_soft_evict_writer_total(&self) -> u64 { - self.pool_drain_soft_evict_writer_total.load(Ordering::Relaxed) - } pub fn get_pool_stale_pick_total(&self) -> u64 { self.pool_stale_pick_total.load(Ordering::Relaxed) } - pub fn get_me_writer_close_signal_drop_total(&self) -> u64 { - self.me_writer_close_signal_drop_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_close_signal_channel_full_total(&self) -> u64 { - self.me_writer_close_signal_channel_full_total - .load(Ordering::Relaxed) - 
} - pub fn get_me_draining_writers_reap_progress_total(&self) -> u64 { - self.me_draining_writers_reap_progress_total - .load(Ordering::Relaxed) - } pub fn get_me_writer_removed_total(&self) -> u64 { self.me_writer_removed_total.load(Ordering::Relaxed) } pub fn get_me_writer_removed_unexpected_total(&self) -> u64 { self.me_writer_removed_unexpected_total.load(Ordering::Relaxed) } - pub fn get_me_writer_teardown_attempt_total( - &self, - reason: MeWriterTeardownReason, - mode: MeWriterTeardownMode, - ) -> u64 { - self.me_writer_teardown_attempt_total[reason.idx()][mode.idx()] - .load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_attempt_total_by_mode(&self, mode: MeWriterTeardownMode) -> u64 { - MeWriterTeardownReason::ALL - .iter() - .copied() - .map(|reason| self.get_me_writer_teardown_attempt_total(reason, mode)) - .sum() - } - pub fn get_me_writer_teardown_success_total(&self, mode: MeWriterTeardownMode) -> u64 { - self.me_writer_teardown_success_total[mode.idx()].load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_timeout_total(&self) -> u64 { - self.me_writer_teardown_timeout_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_escalation_total(&self) -> u64 { - self.me_writer_teardown_escalation_total - .load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_noop_total(&self) -> u64 { - self.me_writer_teardown_noop_total.load(Ordering::Relaxed) - } - pub fn get_me_writer_cleanup_side_effect_failures_total( - &self, - step: MeWriterCleanupSideEffectStep, - ) -> u64 { - self.me_writer_cleanup_side_effect_failures_total[step.idx()] - .load(Ordering::Relaxed) - } - pub fn get_me_writer_cleanup_side_effect_failures_total_all(&self) -> u64 { - MeWriterCleanupSideEffectStep::ALL - .iter() - .copied() - .map(|step| self.get_me_writer_cleanup_side_effect_failures_total(step)) - .sum() - } - pub fn me_writer_teardown_duration_bucket_labels( - ) -> &'static [&'static str; ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT] { - 
&ME_WRITER_TEARDOWN_DURATION_BUCKET_LABELS - } - pub fn get_me_writer_teardown_duration_bucket_hits( - &self, - mode: MeWriterTeardownMode, - bucket_idx: usize, - ) -> u64 { - self.me_writer_teardown_duration_bucket_hits[mode.idx()][bucket_idx] - .load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_duration_bucket_total( - &self, - mode: MeWriterTeardownMode, - bucket_idx: usize, - ) -> u64 { - let capped_idx = bucket_idx.min(ME_WRITER_TEARDOWN_DURATION_BUCKET_COUNT); - let mut total = 0u64; - for idx in 0..=capped_idx { - total = total.saturating_add(self.get_me_writer_teardown_duration_bucket_hits(mode, idx)); - } - total - } - pub fn get_me_writer_teardown_duration_count(&self, mode: MeWriterTeardownMode) -> u64 { - self.me_writer_teardown_duration_count[mode.idx()].load(Ordering::Relaxed) - } - pub fn get_me_writer_teardown_duration_sum_seconds(&self, mode: MeWriterTeardownMode) -> f64 { - self.me_writer_teardown_duration_sum_micros[mode.idx()].load(Ordering::Relaxed) as f64 - / 1_000_000.0 - } pub fn get_me_refill_triggered_total(&self) -> u64 { self.me_refill_triggered_total.load(Ordering::Relaxed) } @@ -1643,11 +1347,35 @@ impl Stats { Self::touch_user_stats(stats.value()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } + + pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option) -> bool { + if !self.telemetry_user_enabled() { + return true; + } + + self.maybe_cleanup_user_stats(); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + + let counter = &stats.curr_connects; + let mut current = counter.load(Ordering::Relaxed); + loop { + if let Some(max) = limit && current >= max { + return false; + } + match counter.compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } pub fn decrement_user_curr_connects(&self, user: &str) { - if 
!self.telemetry_user_enabled() { - return; - } self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { Self::touch_user_stats(stats.value()); @@ -1820,9 +1548,11 @@ impl Stats { // ============= Replay Checker ============= pub struct ReplayChecker { - shards: Vec>, + handshake_shards: Vec>, + tls_shards: Vec>, shard_mask: usize, window: Duration, + tls_window: Duration, checks: AtomicU64, hits: AtomicU64, additions: AtomicU64, @@ -1899,19 +1629,24 @@ impl ReplayShard { impl ReplayChecker { pub fn new(total_capacity: usize, window: Duration) -> Self { + const MIN_TLS_REPLAY_WINDOW: Duration = Duration::from_secs(120); let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); - let mut shards = Vec::with_capacity(num_shards); + let mut handshake_shards = Vec::with_capacity(num_shards); + let mut tls_shards = Vec::with_capacity(num_shards); for _ in 0..num_shards { - shards.push(Mutex::new(ReplayShard::new(cap))); + handshake_shards.push(Mutex::new(ReplayShard::new(cap))); + tls_shards.push(Mutex::new(ReplayShard::new(cap))); } Self { - shards, + handshake_shards, + tls_shards, shard_mask: num_shards - 1, window, + tls_window: window.max(MIN_TLS_REPLAY_WINDOW), checks: AtomicU64::new(0), hits: AtomicU64::new(0), additions: AtomicU64::new(0), @@ -1925,46 +1660,60 @@ impl ReplayChecker { (hasher.finish() as usize) & self.shard_mask } - fn check_and_add_internal(&self, data: &[u8]) -> bool { + fn check_and_add_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { self.checks.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); + let mut shard = shards[idx].lock(); let now = Instant::now(); - let found = shard.check(data, now, self.window); + let found = shard.check(data, now, window); if found { self.hits.fetch_add(1, Ordering::Relaxed); } else { - shard.add(data, now, 
self.window); + shard.add(data, now, window); self.additions.fetch_add(1, Ordering::Relaxed); } found } - fn add_only(&self, data: &[u8]) { + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); - let mut shard = self.shards[idx].lock(); - shard.add(data, Instant::now(), self.window); + let mut shard = shards[idx].lock(); + shard.add(data, Instant::now(), window); } pub fn check_and_add_handshake(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.handshake_shards, self.window) } pub fn check_and_add_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_internal(data) + self.check_and_add_internal(data, &self.tls_shards, self.tls_window) } // Compatibility helpers (non-atomic split operations) — prefer check_and_add_*. pub fn check_handshake(&self, data: &[u8]) -> bool { self.check_and_add_handshake(data) } - pub fn add_handshake(&self, data: &[u8]) { self.add_only(data) } + pub fn add_handshake(&self, data: &[u8]) { + self.add_only(data, &self.handshake_shards, self.window) + } pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check_and_add_tls_digest(data) } - pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data) } + pub fn add_tls_digest(&self, data: &[u8]) { + self.add_only(data, &self.tls_shards, self.tls_window) + } pub fn stats(&self) -> ReplayStats { let mut total_entries = 0; let mut total_queue_len = 0; - for shard in &self.shards { + for shard in &self.handshake_shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + for shard in &self.tls_shards { let s = shard.lock(); total_entries += s.cache.len(); total_queue_len += s.queue.len(); @@ -1977,7 +1726,7 @@ impl ReplayChecker { total_hits: self.hits.load(Ordering::Relaxed), total_additions: self.additions.load(Ordering::Relaxed), total_cleanups: 
self.cleanups.load(Ordering::Relaxed), - num_shards: self.shards.len(), + num_shards: self.handshake_shards.len() + self.tls_shards.len(), window_secs: self.window.as_secs(), } } @@ -1995,13 +1744,20 @@ impl ReplayChecker { let now = Instant::now(); let mut cleaned = 0usize; - for shard_mutex in &self.shards { + for shard_mutex in &self.handshake_shards { let mut shard = shard_mutex.lock(); let before = shard.len(); shard.cleanup(now, self.window); let after = shard.len(); cleaned += before.saturating_sub(after); } + for shard_mutex in &self.tls_shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.tls_window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } self.cleanups.fetch_add(1, Ordering::Relaxed); @@ -2084,79 +1840,6 @@ mod tests { assert_eq!(stats.get_me_keepalive_sent(), 0); assert_eq!(stats.get_me_route_drop_queue_full(), 0); } - - #[test] - fn test_teardown_counters_and_duration() { - let stats = Stats::new(); - stats.increment_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal, - ); - stats.increment_me_writer_teardown_success_total(MeWriterTeardownMode::Normal); - stats.observe_me_writer_teardown_duration( - MeWriterTeardownMode::Normal, - Duration::from_millis(3), - ); - stats.increment_me_writer_cleanup_side_effect_failures_total( - MeWriterCleanupSideEffectStep::CloseSignalChannelFull, - ); - - assert_eq!( - stats.get_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal - ), - 1 - ); - assert_eq!( - stats.get_me_writer_teardown_success_total(MeWriterTeardownMode::Normal), - 1 - ); - assert_eq!( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal), - 1 - ); - assert!( - stats.get_me_writer_teardown_duration_sum_seconds(MeWriterTeardownMode::Normal) > 0.0 - ); - assert_eq!( - stats.get_me_writer_cleanup_side_effect_failures_total( - 
MeWriterCleanupSideEffectStep::CloseSignalChannelFull - ), - 1 - ); - } - - #[test] - fn test_teardown_counters_respect_me_silent() { - let stats = Stats::new(); - stats.apply_telemetry_policy(TelemetryPolicy { - core_enabled: true, - user_enabled: true, - me_level: MeTelemetryLevel::Silent, - }); - stats.increment_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal, - ); - stats.increment_me_writer_teardown_timeout_total(); - stats.observe_me_writer_teardown_duration( - MeWriterTeardownMode::Normal, - Duration::from_millis(1), - ); - assert_eq!( - stats.get_me_writer_teardown_attempt_total( - MeWriterTeardownReason::ReaderExit, - MeWriterTeardownMode::Normal - ), - 0 - ); - assert_eq!(stats.get_me_writer_teardown_timeout_total(), 0); - assert_eq!( - stats.get_me_writer_teardown_duration_count(MeWriterTeardownMode::Normal), - 0 - ); - } #[test] fn test_replay_checker_basic() { @@ -2200,7 +1883,7 @@ mod tests { fn test_replay_checker_many_keys() { let checker = ReplayChecker::new(10_000, Duration::from_secs(60)); for i in 0..500u32 { - checker.add_only(&i.to_le_bytes()); + checker.add_handshake(&i.to_le_bytes()); } for i in 0..500u32 { assert!(checker.check_handshake(&i.to_le_bytes())); @@ -2208,3 +1891,11 @@ mod tests { assert_eq!(checker.stats().total_entries, 500); } } + +#[cfg(test)] +#[path = "tests/connection_lease_security_tests.rs"] +mod connection_lease_security_tests; + +#[cfg(test)] +#[path = "tests/replay_checker_security_tests.rs"] +mod replay_checker_security_tests; diff --git a/src/stats/tests/connection_lease_security_tests.rs b/src/stats/tests/connection_lease_security_tests.rs new file mode 100644 index 0000000..69ae89a --- /dev/null +++ b/src/stats/tests/connection_lease_security_tests.rs @@ -0,0 +1,265 @@ +use super::*; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Barrier; + +#[test] +fn direct_connection_lease_balances_on_drop() { + 
let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_direct(), 0); + + { + let _lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + } + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn middle_connection_lease_balances_on_drop() { + let stats = Arc::new(Stats::new()); + assert_eq!(stats.get_current_connections_me(), 0); + + { + let _lease = stats.acquire_me_connection_lease(); + assert_eq!(stats.get_current_connections_me(), 1); + } + + assert_eq!(stats.get_current_connections_me(), 0); +} + +#[test] +fn connection_lease_disarm_prevents_double_release() { + let stats = Arc::new(Stats::new()); + + let mut lease = stats.acquire_direct_connection_lease(); + assert_eq!(stats.get_current_connections_direct(), 1); + + stats.decrement_current_connections_direct(); + assert_eq!(stats.get_current_connections_direct(), 0); + + lease.disarm(); + drop(lease); + + assert_eq!(stats.get_current_connections_direct(), 0); +} + +#[test] +fn direct_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_direct_connection_lease(); + panic!("intentional panic to verify lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + assert_eq!( + stats.get_current_connections_direct(), + 0, + "panic unwind must release direct route gauge" + ); +} + +#[test] +fn middle_connection_lease_balances_on_panic_unwind() { + let stats = Arc::new(Stats::new()); + let stats_for_panic = stats.clone(); + + let panic_result = panic::catch_unwind(AssertUnwindSafe(move || { + let _lease = stats_for_panic.acquire_me_connection_lease(); + panic!("intentional panic to verify middle lease drop path"); + })); + + assert!(panic_result.is_err(), "panic must propagate from test closure"); + 
assert_eq!( + stats.get_current_connections_me(), + 0, + "panic unwind must release middle route gauge" + ); +} + +#[tokio::test] +async fn concurrent_mixed_route_lease_churn_balances_to_zero() { + const TASKS: usize = 48; + const ITERATIONS_PER_TASK: usize = 256; + + let stats = Arc::new(Stats::new()); + let barrier = Arc::new(Barrier::new(TASKS)); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + let barrier_for_task = barrier.clone(); + workers.push(tokio::spawn(async move { + barrier_for_task.wait().await; + for iter in 0..ITERATIONS_PER_TASK { + if (task_idx + iter) % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::task::yield_now().await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::task::yield_now().await; + } + } + })); + } + + for worker in workers { + worker + .await + .expect("lease churn worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route gauge must return to zero after concurrent lease churn" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route gauge must return to zero after concurrent lease churn" + ); +} + +#[tokio::test] +async fn abort_storm_mixed_route_leases_returns_all_gauges_to_zero() { + const TASKS: usize = 64; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(TASKS); + + for task_idx in 0..TASKS { + let stats_for_task = stats.clone(); + workers.push(tokio::spawn(async move { + if task_idx % 2 == 0 { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } else { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + } + })); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + let total = stats.get_current_connections_direct() + 
stats.get_current_connections_me(); + if total == TASKS as u64 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all storm tasks must acquire route leases before abort"); + + for worker in &workers { + worker.abort(); + } + for worker in workers { + let joined = worker.await; + assert!(joined.is_err(), "aborted worker must return join error"); + } + + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if stats.get_current_connections_direct() == 0 && stats.get_current_connections_me() == 0 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("all route gauges must drain to zero after abort storm"); +} + +#[test] +fn saturating_route_decrements_do_not_underflow_under_race() { + const THREADS: usize = 16; + const DECREMENTS_PER_THREAD: usize = 4096; + + let stats = Arc::new(Stats::new()); + let mut workers = Vec::with_capacity(THREADS); + + for _ in 0..THREADS { + let stats_for_thread = stats.clone(); + workers.push(std::thread::spawn(move || { + for _ in 0..DECREMENTS_PER_THREAD { + stats_for_thread.decrement_current_connections_direct(); + stats_for_thread.decrement_current_connections_me(); + } + })); + } + + for worker in workers { + worker + .join() + .expect("decrement race worker must not panic"); + } + + assert_eq!( + stats.get_current_connections_direct(), + 0, + "direct route decrement races must never underflow" + ); + assert_eq!( + stats.get_current_connections_me(), + 0, + "middle route decrement races must never underflow" + ); +} + +#[tokio::test] +async fn direct_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_direct_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + 
assert_eq!(stats.get_current_connections_direct(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_direct(), + 0, + "aborted task must release direct route gauge" + ); +} + +#[tokio::test] +async fn middle_connection_lease_balances_on_task_abort() { + let stats = Arc::new(Stats::new()); + let stats_for_task = stats.clone(); + + let task = tokio::spawn(async move { + let _lease = stats_for_task.acquire_me_connection_lease(); + tokio::time::sleep(Duration::from_secs(60)).await; + }); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!(stats.get_current_connections_me(), 1); + + task.abort(); + let joined = task.await; + assert!(joined.is_err(), "aborted task must return a join error"); + + tokio::time::sleep(Duration::from_millis(20)).await; + assert_eq!( + stats.get_current_connections_me(), + 0, + "aborted task must release middle route gauge" + ); +} diff --git a/src/stats/tests/replay_checker_security_tests.rs b/src/stats/tests/replay_checker_security_tests.rs new file mode 100644 index 0000000..8e73204 --- /dev/null +++ b/src/stats/tests/replay_checker_security_tests.rs @@ -0,0 +1,80 @@ +use super::*; +use std::time::Duration; + +#[test] +fn replay_checker_keeps_tls_and_handshake_domains_isolated_for_same_key() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let key = b"same-key-domain-separation"; + + assert!( + !checker.check_and_add_handshake(key), + "first handshake use should be fresh" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "same bytes in TLS domain should still be fresh" + ); + + assert!( + checker.check_and_add_handshake(key), + "second handshake use should be replay-hit" + ); + assert!( + checker.check_and_add_tls_digest(key), + "second TLS use should be replay-hit independently" + ); +} + +#[test] +fn 
replay_checker_tls_window_is_clamped_beyond_small_handshake_window() { + let checker = ReplayChecker::new(128, Duration::from_millis(20)); + let handshake_key = b"short-window-handshake"; + let tls_key = b"short-window-tls"; + + assert!(!checker.check_and_add_handshake(handshake_key)); + assert!(!checker.check_and_add_tls_digest(tls_key)); + + std::thread::sleep(Duration::from_millis(80)); + + assert!( + !checker.check_and_add_handshake(handshake_key), + "handshake key should expire under short configured window" + ); + assert!( + checker.check_and_add_tls_digest(tls_key), + "TLS key should still be replay-hit because TLS window is clamped to a secure minimum" + ); +} + +#[test] +fn replay_checker_compat_add_paths_do_not_cross_pollute_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let key = b"compat-domain-separation"; + + checker.add_handshake(key); + assert!( + checker.check_and_add_handshake(key), + "handshake add helper must populate handshake domain" + ); + assert!( + !checker.check_and_add_tls_digest(key), + "handshake add helper must not pollute TLS domain" + ); + + checker.add_tls_digest(key); + assert!( + checker.check_and_add_tls_digest(key), + "TLS add helper must populate TLS domain" + ); +} + +#[test] +fn replay_checker_stats_reflect_dual_shard_domains() { + let checker = ReplayChecker::new(128, Duration::from_secs(1)); + let stats = checker.stats(); + + assert_eq!( + stats.num_shards, 128, + "stats should expose both shard domains (handshake + TLS)" + ); +} diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 2ff7de7..403f695 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec { #[cfg(test)] mod tests { use super::*; + use std::collections::HashSet; use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; @@ -630,4 +631,31 @@ mod tests { let result = codec.decode(&mut buf); 
assert!(result.is_err()); } + + #[test] + fn secure_codec_always_adds_padding_and_jitters_wire_length() { + let codec = SecureCodec::new(Arc::new(SecureRandom::new())); + let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); + let mut wire_lens = HashSet::new(); + + for _ in 0..64 { + let frame = Frame::new(payload.clone()); + let mut out = BytesMut::new(); + codec.encode(&frame, &mut out).unwrap(); + + assert!(out.len() >= 4 + payload.len() + 1); + let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize; + assert!( + (payload.len() + 1..=payload.len() + 3).contains(&wire_len), + "Secure wire length must be payload+1..3, got {wire_len}" + ); + assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned"); + wire_lens.insert(wire_len); + } + + assert!( + wire_lens.len() >= 2, + "Secure padding should create observable wire-length jitter" + ); + } } diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index fe28542..7053a7b 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -12,7 +12,7 @@ //! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for //! domain fronting / traffic camouflage. iOS Telegram clients are known to //! produce slightly different TLS record sizing patterns than Android/Desktop, -//! including records that exceed 16384 payload bytes by a small overhead. +//! including records that exceed MAX_TLS_PLAINTEXT_SIZE payload bytes by a small overhead. //! //! Key design principles: //! - Explicit state machines for all async operations @@ -23,14 +23,13 @@ //! - Proper handling of all TLS record types //! //! Important nuance (Telegram FakeTLS): -//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes. -//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3 +//! - The TLS spec limits "plaintext fragments" to MAX_TLS_PLAINTEXT_SIZE bytes. +//! 
- However, the on-the-wire record length can exceed MAX_TLS_PLAINTEXT_SIZE because TLS 1.3 //! uses AEAD and can include tag/overhead/padding. //! - Telegram FakeTLS clients (notably iOS) may send Application Data records -//! with length up to 16384 + 256 bytes (RFC 8446 §5.2). We accept that as -//! MAX_TLS_CHUNK_SIZE. +//! with length up to MAX_TLS_CIPHERTEXT_SIZE bytes (RFC 8446 §5.2). //! -//! If you reject those (e.g. validate length <= 16384), you will see errors like: +//! If you reject those (e.g. validate length <= MAX_TLS_PLAINTEXT_SIZE), you will see errors like: //! "TLS record too large: 16408 bytes" //! and uploads from iOS will break (media/file sending), while small traffic //! may still work. @@ -42,10 +41,11 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; use crate::protocol::constants::{ + MAX_TLS_PLAINTEXT_SIZE, TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, - MAX_TLS_CHUNK_SIZE, + MAX_TLS_CIPHERTEXT_SIZE, }; use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; @@ -56,7 +56,7 @@ const TLS_HEADER_SIZE: usize = 5; /// Maximum TLS fragment size we emit for Application Data. /// Real TLS 1.3 allows up to 16384 + 256 bytes of ciphertext (incl. tag). -const MAX_TLS_PAYLOAD: usize = 16384 + 256; +const MAX_TLS_PAYLOAD: usize = MAX_TLS_CIPHERTEXT_SIZE; /// Maximum pending write buffer for one record remainder. /// Note: we never queue unlimited amount of data here; state holds at most one record. @@ -93,10 +93,10 @@ impl TlsRecordHeader { /// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01), /// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03). /// - For Application Data, Telegram FakeTLS may send payload length up to - /// MAX_TLS_CHUNK_SIZE (16384 + 256). + /// MAX_TLS_CIPHERTEXT_SIZE (16384 + 256). 
/// - For other record types we keep stricter bounds to avoid memory abuse. fn validate(&self) -> Result<()> { - // Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303). + // Version precheck: only 0x0301 and 0x0303 are recognized at all. if self.version != [0x03, 0x01] && self.version != TLS_VERSION { return Err(Error::new( ErrorKind::InvalidData, @@ -104,31 +104,75 @@ impl TlsRecordHeader { )); } + // Narrow FakeTLS wire profile: TLS 1.0 compatibility header is allowed + // only on Handshake records (ClientHello compatibility quirk). + if self.record_type != TLS_RECORD_HANDSHAKE && self.version != TLS_VERSION { + return Err(Error::new( + ErrorKind::InvalidData, + format!( + "invalid TLS version for record type 0x{:02x}: {:02x?}", + self.record_type, + self.version + ), + )); + } + let len = self.length as usize; // Length checks depend on record type. - // Telegram FakeTLS: ApplicationData length may be 16384 + 256. + // Telegram FakeTLS: ApplicationData may use ciphertext envelope limit, + // while control records stay structurally strict to reduce probe surface. match self.record_type { TLS_RECORD_APPLICATION => { - if len > MAX_TLS_CHUNK_SIZE { + if len == 0 || len > MAX_TLS_CIPHERTEXT_SIZE { return Err(Error::new( ErrorKind::InvalidData, - format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE), + format!( + "invalid TLS application data length: {} (min 1, max {})", + len, + MAX_TLS_CIPHERTEXT_SIZE + ), )); } } - // ChangeCipherSpec/Alert/Handshake should never be that large for our usage - // (post-handshake we don't expect Handshake at all). - // Keep strict to reduce attack surface. 
- _ => { - if len > MAX_TLS_PAYLOAD { + TLS_RECORD_CHANGE_CIPHER => { + if len != 1 { return Err(Error::new( ErrorKind::InvalidData, - format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD), + format!("invalid TLS ChangeCipherSpec length: {} (expected 1)", len), )); } } + + TLS_RECORD_ALERT => { + if len != 2 { + return Err(Error::new( + ErrorKind::InvalidData, + format!("invalid TLS alert length: {} (expected 2)", len), + )); + } + } + + TLS_RECORD_HANDSHAKE => { + if len < 4 || len > MAX_TLS_PLAINTEXT_SIZE { + return Err(Error::new( + ErrorKind::InvalidData, + format!( + "invalid TLS handshake length: {} (min 4, max {})", + len, + MAX_TLS_PLAINTEXT_SIZE + ), + )); + } + } + + _ => { + return Err(Error::new( + ErrorKind::InvalidData, + format!("unknown TLS record type: 0x{:02x}", self.record_type), + )); + } } Ok(()) @@ -250,6 +294,19 @@ impl FakeTlsReader { pub fn get_mut(&mut self) -> &mut R { &mut self.upstream } pub fn into_inner(self) -> R { self.upstream } + pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec<u8>) { + let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) { + TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(), + TlsReaderState::ReadingBody { record_type, buffer, .. } + if record_type == TLS_RECORD_APPLICATION => + { + buffer.to_vec() + } + _ => Vec::new(), + }; + (self.upstream, pending) + } + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } pub fn state_name(&self) -> &'static str { self.state.state_name() } @@ -584,10 +641,10 @@ impl StreamState for TlsWriterState { /// Writer that wraps bytes into TLS 1.3 Application Data records. /// -/// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD). +/// We chunk outgoing data into records of <= MAX_TLS_CIPHERTEXT_SIZE payload bytes. /// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it. 
/// If you want to be more camouflage-accurate later -/// to produce records sized closer to MAX_TLS_CHUNK_SIZE. +/// to produce records sized closer to MAX_TLS_CIPHERTEXT_SIZE. pub struct FakeTlsWriter<W> { upstream: W, state: TlsWriterState, } @@ -823,7 +880,7 @@ impl AsyncWrite for FakeTlsWriter { impl FakeTlsWriter { /// Write all data wrapped in TLS records. /// - /// Convenience method that chunks into <= 16384 records. + /// Convenience method that chunks into <= MAX_TLS_CIPHERTEXT_SIZE records. pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { let mut written = 0; while written < data.len() { @@ -838,6 +895,10 @@ } } +#[cfg(test)] +#[path = "tls_stream_size_adversarial_tests.rs"] +mod size_adversarial_tests; + // ============= Tests ============= #[cfg(test)] @@ -1237,3 +1298,7 @@ mod tests { assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]); } } + +#[cfg(test)] +#[path = "tls_stream_pending_plaintext_security_tests.rs"] +mod pending_plaintext_security_tests; diff --git a/src/stream/tls_stream_pending_plaintext_security_tests.rs b/src/stream/tls_stream_pending_plaintext_security_tests.rs new file mode 100644 index 0000000..30a11ad --- /dev/null +++ b/src/stream/tls_stream_pending_plaintext_security_tests.rs @@ -0,0 +1,143 @@ +use super::*; +use bytes::{Bytes, BytesMut}; + +#[test] +fn reading_body_pending_application_plaintext_is_preserved_on_into_inner() { + let sample = b"coalesced-tail-after-mtproto"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: sample.len(), + buffer: BytesMut::from(&sample[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!( + pending, + sample, + "partial application-data body must survive into fallback path" + ); +} + +#[test] +fn yielding_pending_plaintext_is_preserved_on_into_inner() { + let sample = 
b"already-decoded-buffer"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::Yielding { + buffer: YieldBuffer::new(Bytes::copy_from_slice(sample)), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!(pending, sample); +} + +#[test] +fn reading_body_non_application_record_does_not_produce_plaintext() { + let sample = b"unexpected-handshake-fragment"; + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_HANDSHAKE, + length: sample.len(), + buffer: BytesMut::from(&sample[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!( + pending.is_empty(), + "non-application partial body must not be surfaced as plaintext" + ); +} + +#[test] +fn partial_header_state_does_not_produce_plaintext() { + let mut header = HeaderBuffer::::new(); + let unfilled = header.unfilled_mut(); + unfilled[0] = TLS_RECORD_APPLICATION; + header.advance(1); + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingHeader { header }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty(), "partial header bytes are not plaintext payload"); +} + +#[test] +fn edge_zero_length_application_fragment_remains_empty_without_panics() { + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: 0, + buffer: BytesMut::new(), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty()); +} + +#[test] +fn adversarial_poisoned_state_never_leaks_pending_bytes() { + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::Poisoned { + error: Some(std::io::Error::other("poisoned by adversarial input")), + }; + + let (_inner, pending) = 
reader.into_inner_with_pending_plaintext(); + assert!(pending.is_empty(), "poisoned state must fail-closed for fallback payload"); +} + +#[test] +fn stress_large_application_fragment_survives_state_extraction() { + let mut payload = vec![0u8; 96 * 1024]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(3); + } + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type: TLS_RECORD_APPLICATION, + length: payload.len(), + buffer: BytesMut::from(&payload[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + assert_eq!(pending, payload, "large pending application plaintext must be preserved exactly"); +} + +#[test] +fn light_fuzz_state_matrix_preserves_pending_contract() { + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + + for _ in 0..4096 { + seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed << 8; + + let len = (seed & 0x1ff) as usize; + let mut payload = vec![0u8; len]; + for (idx, b) in payload.iter_mut().enumerate() { + *b = (seed as u8).wrapping_add(idx as u8); + } + + let record_type = match seed & 0x3 { + 0 => TLS_RECORD_APPLICATION, + 1 => TLS_RECORD_HANDSHAKE, + 2 => TLS_RECORD_ALERT, + _ => TLS_RECORD_CHANGE_CIPHER, + }; + + let mut reader = FakeTlsReader::new(tokio::io::empty()); + reader.state = TlsReaderState::ReadingBody { + record_type, + length: payload.len(), + buffer: BytesMut::from(&payload[..]), + }; + + let (_inner, pending) = reader.into_inner_with_pending_plaintext(); + if record_type == TLS_RECORD_APPLICATION { + assert_eq!(pending, payload); + } else { + assert!(pending.is_empty()); + } + } +} diff --git a/src/stream/tls_stream_size_adversarial_tests.rs b/src/stream/tls_stream_size_adversarial_tests.rs new file mode 100644 index 0000000..ec408cd --- /dev/null +++ b/src/stream/tls_stream_size_adversarial_tests.rs @@ -0,0 +1,579 @@ +use super::*; +use crate::protocol::constants::MAX_TLS_PLAINTEXT_SIZE; +use 
tokio::io::{AsyncReadExt, AsyncWriteExt}; + +#[test] +fn handshake_record_above_plaintext_limit_must_be_rejected_early() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: (MAX_TLS_PLAINTEXT_SIZE + 1) as u16, + }; + + assert!( + header.validate().is_err(), + "control-plane handshake record > MAX_TLS_PLAINTEXT_SIZE must fail closed" + ); +} + +#[test] +fn alert_record_above_plaintext_limit_must_be_rejected_early() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: TLS_VERSION, + length: (MAX_TLS_PLAINTEXT_SIZE + 1) as u16, + }; + + assert!( + header.validate().is_err(), + "TLS alert record > MAX_TLS_PLAINTEXT_SIZE must be rejected" + ); +} + +#[test] +fn ccs_record_len_not_equal_one_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_CHANGE_CIPHER, + version: TLS_VERSION, + length: 2, + }; + + assert!( + header.validate().is_err(), + "ChangeCipherSpec length must be exactly 1 byte in compat mode" + ); +} + +#[test] +fn handshake_record_len_zero_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: 0, + }; + + assert!( + header.validate().is_err(), + "zero-length handshake record is structurally invalid" + ); +} + +#[test] +fn handshake_record_len_one_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: 1, + }; + + assert!( + header.validate().is_err(), + "tiny handshake record must be rejected to avoid malformed parser states" + ); +} + +#[test] +fn handshake_record_len_four_is_accepted() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: 4, + }; + + assert!( + header.validate().is_ok(), + "4-byte handshake payload is the minimum carrying handshake header" + ); +} + +#[test] +fn handshake_record_at_plaintext_limit_is_accepted() { + let header = TlsRecordHeader { 
+ record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: MAX_TLS_PLAINTEXT_SIZE as u16, + }; + + assert!( + header.validate().is_ok(), + "handshake record at plaintext RFC limit must be accepted" + ); +} + +#[test] +fn handshake_record_at_ciphertext_limit_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: TLS_VERSION, + length: MAX_TLS_CIPHERTEXT_SIZE as u16, + }; + + assert!( + header.validate().is_err(), + "control-plane handshake must never use ciphertext upper bound" + ); +} + +#[test] +fn alert_record_len_zero_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: TLS_VERSION, + length: 0, + }; + + assert!( + header.validate().is_err(), + "TLS alert must always carry level+description bytes" + ); +} + +#[test] +fn alert_record_len_one_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: TLS_VERSION, + length: 1, + }; + + assert!( + header.validate().is_err(), + "one-byte TLS alert is malformed and must fail closed" + ); +} + +#[test] +fn alert_record_len_two_is_accepted() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: TLS_VERSION, + length: 2, + }; + + assert!( + header.validate().is_ok(), + "standard TLS alert shape should be accepted" + ); +} + +#[test] +fn alert_record_len_three_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: TLS_VERSION, + length: 3, + }; + + assert!( + header.validate().is_err(), + "oversized plaintext alert should be rejected to avoid parser confusion" + ); +} + +#[test] +fn ccs_record_len_zero_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_CHANGE_CIPHER, + version: TLS_VERSION, + length: 0, + }; + + assert!( + header.validate().is_err(), + "ChangeCipherSpec with zero length is malformed" + ); +} + +#[test] +fn ccs_record_len_one_is_accepted() { + let header = 
TlsRecordHeader { + record_type: TLS_RECORD_CHANGE_CIPHER, + version: TLS_VERSION, + length: 1, + }; + + assert!( + header.validate().is_ok(), + "ChangeCipherSpec compat record length must be accepted only for len=1" + ); +} + +#[test] +fn ccs_record_len_at_plaintext_limit_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_CHANGE_CIPHER, + version: TLS_VERSION, + length: MAX_TLS_PLAINTEXT_SIZE as u16, + }; + + assert!( + header.validate().is_err(), + "oversized CCS control frame must fail closed" + ); +} + +#[test] +fn unknown_record_type_small_len_must_be_rejected_early() { + let header = TlsRecordHeader { + record_type: 0x19, + version: TLS_VERSION, + length: 8, + }; + + assert!( + header.validate().is_err(), + "unknown TLS record type should be rejected during header validation" + ); +} + +#[test] +fn unknown_record_type_large_len_must_be_rejected_early() { + let header = TlsRecordHeader { + record_type: 0x7f, + version: TLS_VERSION, + length: MAX_TLS_CIPHERTEXT_SIZE as u16, + }; + + assert!( + header.validate().is_err(), + "unknown record type with large payload must fail before body allocation" + ); +} + +#[test] +fn handshake_tls10_header_with_plaintext_plus_one_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_HANDSHAKE, + version: [0x03, 0x01], + length: (MAX_TLS_PLAINTEXT_SIZE + 1) as u16, + }; + + assert!( + header.validate().is_err(), + "TLS 1.0 compatibility header must not bypass plaintext size cap" + ); +} + +#[test] +fn alert_tls10_header_with_invalid_len_must_be_rejected() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_ALERT, + version: [0x03, 0x01], + length: 3, + }; + + assert!( + header.validate().is_err(), + "TLS 1.0 compatibility header must not bypass strict alert framing" + ); +} + +fn validates(record_type: u8, version: [u8; 2], length: u16) -> bool { + TlsRecordHeader { + record_type, + version, + length, + } + .validate() + .is_ok() +} + +macro_rules! 
expect_reject { + ($name:ident, $record_type:expr, $version:expr, $length:expr) => { + #[test] + fn $name() { + assert!( + !validates($record_type, $version, $length), + "expected reject for type=0x{:02x} version={:02x?} len={}", + $record_type, + $version, + $length + ); + } + }; +} + +macro_rules! expect_accept { + ($name:ident, $record_type:expr, $version:expr, $length:expr) => { + #[test] + fn $name() { + assert!( + validates($record_type, $version, $length), + "expected accept for type=0x{:02x} version={:02x?} len={}", + $record_type, + $version, + $length + ); + } + }; +} + +expect_reject!(appdata_zero_len_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, 0); +expect_accept!(appdata_one_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1); +expect_accept!(appdata_small_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 32); +expect_accept!(appdata_medium_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1024); +expect_accept!(appdata_plaintext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_PLAINTEXT_SIZE as u16); +expect_accept!(appdata_ciphertext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16); +expect_reject!(appdata_ciphertext_plus_one_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_CIPHERTEXT_SIZE as u16) + 1); + +expect_reject!(appdata_tls10_header_len_one_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1); +expect_reject!(appdata_tls10_header_medium_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1024); +expect_reject!(appdata_tls10_header_ciphertext_limit_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], MAX_TLS_CIPHERTEXT_SIZE as u16); + +expect_reject!(ccs_tls10_header_len_one_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 1); +expect_reject!(ccs_tls10_header_len_zero_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 0); +expect_reject!(ccs_tls10_header_len_two_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 2); 
+ +expect_reject!(alert_tls10_header_len_two_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 2); +expect_reject!(alert_tls10_header_len_one_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 1); +expect_reject!(alert_tls10_header_len_three_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 3); + +expect_accept!(handshake_tls10_header_min_len_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 4); +expect_accept!(handshake_tls10_header_plaintext_limit_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], MAX_TLS_PLAINTEXT_SIZE as u16); +expect_reject!(handshake_tls10_header_too_small_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 3); +expect_reject!(handshake_tls10_header_too_large_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], (MAX_TLS_PLAINTEXT_SIZE as u16) + 1); + +expect_reject!(unknown_type_tls13_zero_must_be_rejected, 0x00, TLS_VERSION, 0); +expect_reject!(unknown_type_tls13_small_must_be_rejected, 0x13, TLS_VERSION, 32); +expect_reject!(unknown_type_tls13_large_must_be_rejected, 0xfe, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16); +expect_reject!(unknown_type_tls10_small_must_be_rejected, 0x13, [0x03, 0x01], 32); + +expect_reject!(appdata_invalid_version_0302_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x02], 128); +expect_reject!(handshake_invalid_version_0302_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x02], 128); +expect_reject!(alert_invalid_version_0302_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x02], 2); +expect_reject!(ccs_invalid_version_0302_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x02], 1); + +expect_reject!(appdata_invalid_version_0304_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x04], 128); +expect_reject!(handshake_invalid_version_0304_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x04], 128); +expect_reject!(alert_invalid_version_0304_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x04], 2); +expect_reject!(ccs_invalid_version_0304_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x04], 1); + 
+expect_accept!(handshake_tls13_len_5_is_accepted, TLS_RECORD_HANDSHAKE, TLS_VERSION, 5); +expect_accept!(appdata_tls13_len_16385_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_PLAINTEXT_SIZE as u16) + 1); + +#[test] +fn matrix_version_policy_is_strict_and_deterministic() { + let versions = [[0x03, 0x01], TLS_VERSION, [0x03, 0x02], [0x03, 0x04], [0x00, 0x00]]; + let record_types = [ + TLS_RECORD_APPLICATION, + TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_ALERT, + TLS_RECORD_HANDSHAKE, + ]; + + for version in versions { + for record_type in record_types { + let len = match record_type { + TLS_RECORD_APPLICATION => 1, + TLS_RECORD_CHANGE_CIPHER => 1, + TLS_RECORD_ALERT => 2, + TLS_RECORD_HANDSHAKE => 4, + _ => unreachable!(), + }; + + let accepted = validates(record_type, version, len); + let expected = if version == TLS_VERSION { + true + } else { + version == [0x03, 0x01] && record_type == TLS_RECORD_HANDSHAKE + }; + + assert_eq!( + accepted, expected, + "version policy mismatch for type=0x{:02x} version={:02x?}", + record_type, version + ); + } + } +} + +#[test] +fn appdata_partition_property_holds_for_all_u16_edges() { + for len in [0u16, 1, 2, 3, 64, 255, 1024, 4096, 8192, 16_384, 16_385, 16_640, 16_641, u16::MAX] { + let accepted = validates(TLS_RECORD_APPLICATION, TLS_VERSION, len); + let expected = len >= 1 && usize::from(len) <= MAX_TLS_CIPHERTEXT_SIZE; + assert_eq!(accepted, expected, "unexpected appdata decision for len={len}"); + } +} + +#[test] +fn handshake_partition_property_holds_for_all_u16_edges() { + for len in [0u16, 1, 2, 3, 4, 5, 64, 255, 1024, 4096, 8192, 16_383, 16_384, 16_385, u16::MAX] { + let accepted_tls13 = validates(TLS_RECORD_HANDSHAKE, TLS_VERSION, len); + let accepted_tls10 = validates(TLS_RECORD_HANDSHAKE, [0x03, 0x01], len); + let expected = (4..=MAX_TLS_PLAINTEXT_SIZE).contains(&usize::from(len)); + + assert_eq!(accepted_tls13, expected, "TLS1.3 handshake mismatch for len={len}"); + assert_eq!(accepted_tls10, expected, 
"TLS1.0 compat handshake mismatch for len={len}"); + } +} + +#[test] +fn control_record_exact_lengths_are_enforced_under_fuzzed_lengths() { + let mut x: u32 = 0xC0FFEE11; + for _ in 0..5000 { + x = x.wrapping_mul(1664525).wrapping_add(1013904223); + let len = (x & 0xFFFF) as u16; + + let ccs_ok = validates(TLS_RECORD_CHANGE_CIPHER, TLS_VERSION, len); + let alert_ok = validates(TLS_RECORD_ALERT, TLS_VERSION, len); + + assert_eq!(ccs_ok, len == 1, "ccs length gate mismatch for len={len}"); + assert_eq!(alert_ok, len == 2, "alert length gate mismatch for len={len}"); + } +} + +#[test] +fn unknown_record_types_never_validate_under_supported_versions() { + for record_type in 0u8..=255 { + if matches!(record_type, TLS_RECORD_APPLICATION | TLS_RECORD_CHANGE_CIPHER | TLS_RECORD_ALERT | TLS_RECORD_HANDSHAKE) { + continue; + } + + assert!( + !validates(record_type, TLS_VERSION, 1), + "unknown type must not validate under TLS_VERSION: 0x{record_type:02x}" + ); + assert!( + !validates(record_type, [0x03, 0x01], 4), + "unknown type must not validate under TLS1.0 compat: 0x{record_type:02x}" + ); + } +} + +#[tokio::test] +async fn reader_rejects_tls10_appdata_header_before_payload_processing() { + let (mut tx, rx) = tokio::io::duplex(128); + tx.write_all(&[TLS_RECORD_APPLICATION, 0x03, 0x01, 0x00, 0x01, 0xAB]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 1]; + let err = reader.read(&mut out).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); +} + +#[tokio::test] +async fn reader_rejects_zero_len_appdata_record() { + let (mut tx, rx) = tokio::io::duplex(128); + tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x00]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 1]; + let err = reader.read(&mut out).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); 
+} + +#[tokio::test] +async fn reader_accepts_single_byte_tls13_appdata_and_yields_payload() { + let (mut tx, rx) = tokio::io::duplex(128); + tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x01, 0x5A]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 1]; + let n = reader.read(&mut out).await.unwrap(); + assert_eq!(n, 1); + assert_eq!(out[0], 0x5A); +} + +#[tokio::test] +async fn reader_rejects_tls10_alert_even_with_structural_length() { + let (mut tx, rx) = tokio::io::duplex(128); + tx.write_all(&[TLS_RECORD_ALERT, 0x03, 0x01, 0x00, 0x02, 0x02, 0x28]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 8]; + let err = reader.read(&mut out).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); +} + +#[tokio::test] +async fn reader_rejects_unknown_record_type_fast() { + let (mut tx, rx) = tokio::io::duplex(128); + tx.write_all(&[0x7f, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x01, 0x01]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 8]; + let err = reader.read(&mut out).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); +} + +#[tokio::test] +async fn reader_preserves_data_after_valid_ccs_then_valid_appdata() { + let (mut tx, rx) = tokio::io::duplex(256); + tx.write_all(&[ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x01, + 0x01, + TLS_RECORD_APPLICATION, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x03, + 0xDE, + 0xAD, + 0xBE, + ]) + .await + .unwrap(); + tx.shutdown().await.unwrap(); + + let mut reader = FakeTlsReader::new(rx); + let mut out = [0u8; 3]; + let n = reader.read(&mut out).await.unwrap(); + assert_eq!(n, 3); + assert_eq!(out, [0xDE, 0xAD, 0xBE]); +} + +#[test] +fn deterministic_lcg_never_breaks_validation_invariants() { + let mut x: 
u64 = 0xD1A5_CE55_0BAD_F00D; + for _ in 0..20000 { + x = x.wrapping_mul(6364136223846793005).wrapping_add(1); + let record_type = (x & 0xFF) as u8; + let version = match (x >> 8) & 0x3 { + 0 => TLS_VERSION, + 1 => [0x03, 0x01], + 2 => [0x03, 0x02], + _ => [0x03, 0x04], + }; + let len = ((x >> 16) & 0xFFFF) as u16; + + let accepted = validates(record_type, version, len); + + let expected = match record_type { + TLS_RECORD_APPLICATION => { + version == TLS_VERSION && len >= 1 && usize::from(len) <= MAX_TLS_CIPHERTEXT_SIZE + } + TLS_RECORD_CHANGE_CIPHER => version == TLS_VERSION && len == 1, + TLS_RECORD_ALERT => version == TLS_VERSION && len == 2, + TLS_RECORD_HANDSHAKE => { + (version == TLS_VERSION || version == [0x03, 0x01]) + && (4..=MAX_TLS_PLAINTEXT_SIZE).contains(&usize::from(len)) + } + _ => false, + }; + + assert_eq!( + accepted, expected, + "invariant mismatch: type=0x{record_type:02x} version={version:02x?} len={len}" + ); + } +} diff --git a/src/tests/ip_tracker_hotpath_adversarial_tests.rs b/src/tests/ip_tracker_hotpath_adversarial_tests.rs new file mode 100644 index 0000000..53c4123 --- /dev/null +++ b/src/tests/ip_tracker_hotpath_adversarial_tests.rs @@ -0,0 +1,168 @@ +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use std::time::Duration; + +use crate::config::UserMaxUniqueIpsMode; +use crate::ip_tracker::UserIpTracker; + +fn ip_from_idx(idx: u32) -> IpAddr { + IpAddr::V4(Ipv4Addr::new(10, ((idx >> 16) & 0xff) as u8, ((idx >> 8) & 0xff) as u8, (idx & 0xff) as u8)) +} + +#[tokio::test] +async fn hotpath_empty_drain_is_idempotent() { + let tracker = UserIpTracker::new(); + for _ in 0..128 { + tracker.drain_cleanup_queue().await; + } + assert_eq!(tracker.get_active_ip_count("none").await, 0); +} + +#[tokio::test] +async fn hotpath_batch_cleanup_drain_clears_all_active_entries() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u", 100).await; + + for idx in 0..32 { + let ip = ip_from_idx(idx); + tracker.check_and_add("u", 
ip).await.unwrap(); + tracker.enqueue_cleanup("u".to_string(), ip); + } + + tracker.drain_cleanup_queue().await; + assert_eq!(tracker.get_active_ip_count("u").await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn hotpath_parallel_enqueue_and_drain_does_not_deadlock() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("p", 64).await; + + let mut tasks = Vec::new(); + for worker in 0..32u32 { + let t = tracker.clone(); + tasks.push(tokio::spawn(async move { + let ip = ip_from_idx(1_000 + worker); + for _ in 0..64 { + let _ = t.check_and_add("p", ip).await; + t.enqueue_cleanup("p".to_string(), ip); + t.drain_cleanup_queue().await; + } + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(3), task) + .await + .expect("worker must not deadlock") + .expect("worker task must not panic"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn hotpath_parallel_unique_ip_limit_never_exceeds_cap() { + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("limit", 5).await; + + let mut tasks = Vec::new(); + for idx in 0..64u32 { + let t = tracker.clone(); + tasks.push(tokio::spawn(async move { t.check_and_add("limit", ip_from_idx(idx)).await.is_ok() })); + } + + let mut admitted = 0usize; + for task in tasks { + if task.await.expect("task must not panic") { + admitted += 1; + } + } + + assert!(admitted <= 5, "admitted unique IPs must not exceed configured cap"); + assert!(tracker.get_active_ip_count("limit").await <= 5); +} + +#[tokio::test] +async fn hotpath_repeated_same_ip_counter_balances_to_zero() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("same", 1).await; + let ip = ip_from_idx(77); + + for _ in 0..512 { + tracker.check_and_add("same", ip).await.unwrap(); + } + for _ in 0..512 { + tracker.remove_ip("same", ip).await; + } + + assert_eq!(tracker.get_active_ip_count("same").await, 0); +} + +#[tokio::test] +async fn 
hotpath_light_fuzz_mixed_operations_preserve_limit_invariants() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("fuzz", 4).await; + + let mut state: u64 = 0xA55A_5AA5_D15C_B00B; + for _ in 0..4_000 { + state ^= state << 7; + state ^= state >> 9; + state ^= state << 8; + + let ip = ip_from_idx((state as u32) % 8); + match state & 0x3 { + 0 | 1 => { + let _ = tracker.check_and_add("fuzz", ip).await; + } + _ => { + tracker.remove_ip("fuzz", ip).await; + } + } + + assert!( + tracker.get_active_ip_count("fuzz").await <= 4, + "active count must stay within configured cap" + ); + } +} + +#[tokio::test] +async fn hotpath_multi_user_churn_keeps_isolation() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("u1", 2).await; + tracker.set_user_limit("u2", 3).await; + + for idx in 0..200u32 { + let ip1 = ip_from_idx(idx % 5); + let ip2 = ip_from_idx(100 + (idx % 7)); + let _ = tracker.check_and_add("u1", ip1).await; + let _ = tracker.check_and_add("u2", ip2).await; + if idx % 2 == 0 { + tracker.remove_ip("u1", ip1).await; + } + if idx % 3 == 0 { + tracker.remove_ip("u2", ip2).await; + } + } + + assert!(tracker.get_active_ip_count("u1").await <= 2); + assert!(tracker.get_active_ip_count("u2").await <= 3); +} + +#[tokio::test] +async fn hotpath_time_window_expiry_allows_new_ip_after_window() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("tw", 1).await; + tracker + .set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1) + .await; + + let ip1 = ip_from_idx(901); + let ip2 = ip_from_idx(902); + + tracker.check_and_add("tw", ip1).await.unwrap(); + tracker.remove_ip("tw", ip1).await; + assert!(tracker.check_and_add("tw", ip2).await.is_err()); + + tokio::time::sleep(Duration::from_millis(1_100)).await; + assert!(tracker.check_and_add("tw", ip2).await.is_ok()); +} diff --git a/src/ip_tracker_regression_tests.rs b/src/tests/ip_tracker_regression_tests.rs similarity index 69% rename from src/ip_tracker_regression_tests.rs rename to 
src/tests/ip_tracker_regression_tests.rs index 5d6b358..57e135d 100644 --- a/src/ip_tracker_regression_tests.rs +++ b/src/tests/ip_tracker_regression_tests.rs @@ -448,3 +448,172 @@ async fn concurrent_reconnect_and_disconnect_preserves_non_negative_counts() { assert!(tracker.get_active_ip_count("cc").await <= 8); } + +#[tokio::test] +async fn enqueue_cleanup_recovers_from_poisoned_mutex() { + let tracker = UserIpTracker::new(); + let ip = ip_from_idx(99); + + // Poison the lock by panicking while holding it + let result = std::panic::catch_unwind(|| { + let _guard = tracker.cleanup_queue.lock().unwrap(); + panic!("Intentional poison panic"); + }); + assert!(result.is_err(), "Expected panic to poison mutex"); + + // Attempt to enqueue anyway; should hit the poison catch arm and still insert + tracker.enqueue_cleanup("poison-user".to_string(), ip); + + tracker.drain_cleanup_queue().await; + + assert_eq!(tracker.get_active_ip_count("poison-user").await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mass_reconnect_sync_cleanup_prevents_temporary_reservation_bloat() { + // Tests that synchronous M-01 drop mechanism protects against starvation + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("mass", 5).await; + + let ip = ip_from_idx(42); + let mut join_handles = Vec::new(); + + // 10,000 rapid concurrent requests hitting the same IP limit + for _ in 0..10_000 { + let tracker_clone = tracker.clone(); + join_handles.push(tokio::spawn(async move { + if tracker_clone.check_and_add("mass", ip).await.is_ok() { + // Instantly enqueue cleanup, simulating synchronous reservation drop + tracker_clone.enqueue_cleanup("mass".to_string(), ip); + // The next caller will drain it before acquiring again + } + })); + } + + for handle in join_handles { + let _ = handle.await; + } + + // Force flush + tracker.drain_cleanup_queue().await; + assert_eq!(tracker.get_active_ip_count("mass").await, 0, "No leaked footprints"); +} + 
+#[tokio::test] +async fn adversarial_drain_cleanup_queue_race_does_not_cause_false_rejections() { + // Regression guard: concurrent cleanup draining must not produce false + // limit denials for a new IP when the previous IP is already queued. + let tracker = Arc::new(UserIpTracker::new()); + tracker.set_user_limit("racer", 1).await; + let ip1 = ip_from_idx(1); + let ip2 = ip_from_idx(2); + + // Initial state: add ip1 + tracker.check_and_add("racer", ip1).await.unwrap(); + + // User disconnects from ip1, queuing it + tracker.enqueue_cleanup("racer".to_string(), ip1); + + let mut saw_false_rejection = false; + for _ in 0..100 { + // Queue cleanup then race explicit drain and check-and-add on the alternative IP. + tracker.enqueue_cleanup("racer".to_string(), ip1); + let tracker_a = tracker.clone(); + let tracker_b = tracker.clone(); + + let drain_handle = tokio::spawn(async move { + tracker_a.drain_cleanup_queue().await; + }); + let handle = tokio::spawn(async move { + tracker_b.check_and_add("racer", ip2).await + }); + + drain_handle.await.unwrap(); + let res = handle.await.unwrap(); + if res.is_err() { + saw_false_rejection = true; + break; + } + + // Restore baseline for next iteration. + tracker.remove_ip("racer", ip2).await; + tracker.check_and_add("racer", ip1).await.unwrap(); + } + + assert!( + !saw_false_rejection, + "Concurrent cleanup draining must not cause false-positive IP denials" + ); +} + +#[tokio::test] +async fn poisoned_cleanup_queue_still_releases_slot_for_next_ip() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("poison-slot", 1).await; + let ip1 = ip_from_idx(7001); + let ip2 = ip_from_idx(7002); + + tracker.check_and_add("poison-slot", ip1).await.unwrap(); + + // Poison the queue lock as an adversarial condition. + let _ = std::panic::catch_unwind(|| { + let _guard = tracker.cleanup_queue.lock().unwrap(); + panic!("intentional queue poison"); + }); + + // Disconnect path must still queue cleanup so the next IP can be admitted. 
+ tracker.enqueue_cleanup("poison-slot".to_string(), ip1); + let admitted = tracker.check_and_add("poison-slot", ip2).await; + assert!( + admitted.is_ok(), + "cleanup queue poison must not permanently block slot release for the next IP" + ); +} + +#[tokio::test] +async fn duplicate_cleanup_entries_do_not_break_future_admission() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("dup-cleanup", 1).await; + let ip1 = ip_from_idx(7101); + let ip2 = ip_from_idx(7102); + + tracker.check_and_add("dup-cleanup", ip1).await.unwrap(); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + tracker.enqueue_cleanup("dup-cleanup".to_string(), ip1); + + tracker.drain_cleanup_queue().await; + + assert_eq!(tracker.get_active_ip_count("dup-cleanup").await, 0); + assert!( + tracker.check_and_add("dup-cleanup", ip2).await.is_ok(), + "extra queued cleanup entries must not leave user stuck in denied state" + ); +} + +#[tokio::test] +async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() { + let tracker = UserIpTracker::new(); + tracker.set_user_limit("poison-stress", 1).await; + let ip_primary = ip_from_idx(7201); + let ip_alt = ip_from_idx(7202); + + tracker.check_and_add("poison-stress", ip_primary).await.unwrap(); + + for _ in 0..64 { + let _ = std::panic::catch_unwind(|| { + let _guard = tracker.cleanup_queue.lock().unwrap(); + panic!("intentional queue poison in stress loop"); + }); + + tracker.enqueue_cleanup("poison-stress".to_string(), ip_primary); + + assert!( + tracker.check_and_add("poison-stress", ip_alt).await.is_ok(), + "poison recovery must preserve admission progress under repeated queue poisoning" + ); + + tracker.remove_ip("poison-stress", ip_alt).await; + tracker.check_and_add("poison-stress", ip_primary).await.unwrap(); + } +} diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 3278f63..e8fdf16 100644 --- a/src/tls_front/emulator.rs +++ 
b/src/tls_front/emulator.rs @@ -1,12 +1,13 @@ use crate::crypto::{sha256_hmac, SecureRandom}; use crate::protocol::constants::{ + MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION, }; use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key}; use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo, TlsProfileSource}; const MIN_APP_DATA: usize = 64; -const MAX_APP_DATA: usize = 16640; // RFC 8446 §5.2 allows up to 2^14 + 256 +const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec { sizes @@ -117,15 +118,6 @@ pub fn build_emulated_server_hello( extensions.extend_from_slice(&0x002bu16.to_be_bytes()); extensions.extend_from_slice(&(2u16).to_be_bytes()); extensions.extend_from_slice(&0x0304u16.to_be_bytes()); - if let Some(alpn_proto) = &alpn { - extensions.extend_from_slice(&0x0010u16.to_be_bytes()); - let list_len: u16 = 1 + alpn_proto.len() as u16; - let ext_len: u16 = 2 + list_len; - extensions.extend_from_slice(&ext_len.to_be_bytes()); - extensions.extend_from_slice(&list_len.to_be_bytes()); - extensions.push(alpn_proto.len() as u8); - extensions.extend_from_slice(alpn_proto); - } let extensions_len = extensions.len() as u16; let body_len = 2 + // version @@ -207,8 +199,22 @@ pub fn build_emulated_server_hello( } let mut app_data = Vec::new(); + let alpn_marker = alpn + .as_ref() + .filter(|p| !p.is_empty() && p.len() <= u8::MAX as usize) + .map(|proto| { + let proto_list_len = 1usize + proto.len(); + let ext_data_len = 2usize + proto_list_len; + let mut marker = Vec::with_capacity(4 + ext_data_len); + marker.extend_from_slice(&0x0010u16.to_be_bytes()); + marker.extend_from_slice(&(ext_data_len as u16).to_be_bytes()); + marker.extend_from_slice(&(proto_list_len as u16).to_be_bytes()); + marker.push(proto.len() as u8); + marker.extend_from_slice(proto); + marker + }); let mut payload_offset = 0usize; - for 
size in sizes { + for (idx, size) in sizes.into_iter().enumerate() { let mut rec = Vec::with_capacity(5 + size); rec.push(TLS_RECORD_APPLICATION); rec.extend_from_slice(&TLS_VERSION); @@ -233,7 +239,20 @@ pub fn build_emulated_server_hello( } } else if size > 17 { let body_len = size - 17; - rec.extend_from_slice(&rng.bytes(body_len)); + let mut body = Vec::with_capacity(body_len); + if idx == 0 && let Some(marker) = &alpn_marker { + if marker.len() <= body_len { + body.extend_from_slice(marker); + if body_len > marker.len() { + body.extend_from_slice(&rng.bytes(body_len - marker.len())); + } + } else { + body.extend_from_slice(&rng.bytes(body_len)); + } + } else { + body.extend_from_slice(&rng.bytes(body_len)); + } + rec.extend_from_slice(&body); rec.push(0x16); // inner content type marker (handshake) rec.extend_from_slice(&rng.bytes(16)); // AEAD-like tag } else { @@ -245,8 +264,9 @@ pub fn build_emulated_server_hello( // --- Combine --- // Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint). 
let mut tickets = Vec::new(); - if new_session_tickets > 0 { - for _ in 0..new_session_tickets { + let ticket_count = new_session_tickets.min(4); + if ticket_count > 0 { + for _ in 0..ticket_count { let ticket_len: usize = rng.range(48) + 48; let mut rec = Vec::with_capacity(5 + ticket_len); rec.push(TLS_RECORD_APPLICATION); @@ -273,6 +293,10 @@ pub fn build_emulated_server_hello( response } +#[cfg(test)] +#[path = "tests/emulator_security_tests.rs"] +mod security_tests; + #[cfg(test)] mod tests { use std::time::SystemTime; diff --git a/src/tls_front/tests/emulator_security_tests.rs b/src/tls_front/tests/emulator_security_tests.rs new file mode 100644 index 0000000..c49d15a --- /dev/null +++ b/src/tls_front/tests/emulator_security_tests.rs @@ -0,0 +1,136 @@ +use std::time::SystemTime; + +use crate::crypto::SecureRandom; +use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE}; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{ + CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource, +}; + +fn make_cached(cert_payload: Option) -> CachedTlsData { + CachedTlsData { + server_hello_template: ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload, + app_data_records_sizes: vec![64], + total_app_data_len: 64, + behavior_profile: TlsBehaviorProfile { + change_cipher_spec_count: 1, + app_data_record_sizes: vec![64], + ticket_record_sizes: Vec::new(), + source: TlsProfileSource::Default, + }, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + } +} + +fn first_app_data_payload(response: &[u8]) -> &[u8] { + let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = 5 + hello_len; + let ccs_len = u16::from_be_bytes([response[ccs_start + 3], 
response[ccs_start + 4]]) as usize; + let app_start = ccs_start + 5 + ccs_len; + let app_len = u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize; + &response[app_start + 5..app_start + 5 + app_len] +} + +#[test] +fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() { + let cached = make_cached(None); + let rng = SecureRandom::new(); + let oversized_alpn = vec![0xAB; u8::MAX as usize + 1]; + + let response = build_emulated_server_hello( + b"secret", + &[0x11; 32], + &[0x22; 16], + &cached, + true, + &rng, + Some(oversized_alpn), + 0, + ); + + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize; + let ccs_start = 5 + hello_len; + assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER); + let app_start = ccs_start + 6; + assert_eq!(response[app_start], TLS_RECORD_APPLICATION); + + let payload = first_app_data_payload(&response); + let mut marker_prefix = Vec::new(); + marker_prefix.extend_from_slice(&0x0010u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x0102u16.to_be_bytes()); + marker_prefix.extend_from_slice(&0x0100u16.to_be_bytes()); + marker_prefix.push(0xff); + marker_prefix.extend_from_slice(&[0xab; 8]); + assert!( + !payload.starts_with(&marker_prefix), + "oversized ALPN must not be partially embedded into the emulated first application record" + ); +} + +#[test] +fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() { + let cached = make_cached(None); + let rng = SecureRandom::new(); + + let response = build_emulated_server_hello( + b"secret", + &[0x31; 32], + &[0x41; 16], + &cached, + true, + &rng, + Some(b"h2".to_vec()), + 0, + ); + + let payload = first_app_data_payload(&response); + let expected = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; + assert!( + payload.starts_with(&expected), + "when body has enough capacity, emulated first application record must include full ALPN marker" + ); +} + 
+#[test] +fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() { + let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]; + let cached = make_cached(Some(TlsCertPayload { + cert_chain_der: vec![vec![0x30, 0x01, 0x00]], + certificate_message: cert_msg.clone(), + })); + let rng = SecureRandom::new(); + + let response = build_emulated_server_hello( + b"secret", + &[0x32; 32], + &[0x42; 16], + &cached, + true, + &rng, + Some(b"h2".to_vec()), + 0, + ); + + let payload = first_app_data_payload(&response); + let alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2']; + + assert!( + payload.starts_with(&cert_msg), + "when certificate payload is available, first record must start with cert payload bytes" + ); + assert!( + !payload.starts_with(&alpn_marker), + "ALPN marker must not displace selected certificate payload" + ); +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index d53b4ef..fa54a27 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -4,16 +4,14 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; -use rand::Rng; +use rand::RngExt; use tracing::{debug, info, warn}; use crate::config::MeFloorMode; use crate::crypto::SecureRandom; use crate::network::IpFamily; -use crate::stats::MeWriterTeardownReason; use super::MePool; -use super::pool::{MeFamilyRuntimeState, MeWriter}; const JITTER_FRAC_NUM: u64 = 2; // jitter up to 50% of backoff #[allow(dead_code)] @@ -30,37 +28,7 @@ const HEALTH_RECONNECT_BUDGET_MAX: usize = 128; const HEALTH_DRAIN_CLOSE_BUDGET_PER_CORE: usize = 16; const HEALTH_DRAIN_CLOSE_BUDGET_MIN: usize = 16; const HEALTH_DRAIN_CLOSE_BUDGET_MAX: usize = 256; -const HEALTH_DRAIN_SOFT_EVICT_BUDGET_MIN: usize = 8; -const HEALTH_DRAIN_SOFT_EVICT_BUDGET_MAX: usize = 256; -const HEALTH_DRAIN_REAP_OPPORTUNISTIC_INTERVAL_SECS: u64 = 1; const HEALTH_DRAIN_TIMEOUT_ENFORCER_INTERVAL_SECS: u64 = 1; 
-const FAMILY_SUPPRESS_FAIL_STREAK_THRESHOLD: u32 = 6; -const FAMILY_SUPPRESS_WINDOW_SECS: u64 = 120; -const FAMILY_RECOVER_PROBE_INTERVAL_SECS: u64 = 5; -const FAMILY_RECOVER_SUCCESS_STREAK_REQUIRED: u32 = 3; - -#[derive(Debug, Clone)] -struct FamilyCircuitState { - state: MeFamilyRuntimeState, - state_since_at: Instant, - suppressed_until: Option, - next_probe_at: Instant, - fail_streak: u32, - recover_success_streak: u32, -} - -impl FamilyCircuitState { - fn new(now: Instant) -> Self { - Self { - state: MeFamilyRuntimeState::Healthy, - state_since_at: now, - suppressed_until: None, - next_probe_at: now, - fail_streak: 0, - recover_success_streak: 0, - } - } -} #[derive(Debug, Clone)] struct DcFloorPlanEntry { @@ -99,26 +67,6 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c let mut adaptive_recover_until: HashMap<(i32, IpFamily), Instant> = HashMap::new(); let mut floor_warn_next_allowed: HashMap<(i32, IpFamily), Instant> = HashMap::new(); let mut drain_warn_next_allowed: HashMap = HashMap::new(); - let mut drain_soft_evict_next_allowed: HashMap = HashMap::new(); - let mut family_v4_circuit = FamilyCircuitState::new(Instant::now()); - let mut family_v6_circuit = FamilyCircuitState::new(Instant::now()); - let init_epoch_secs = MePool::now_epoch_secs(); - pool.set_family_runtime_state( - IpFamily::V4, - family_v4_circuit.state, - init_epoch_secs, - 0, - family_v4_circuit.fail_streak, - family_v4_circuit.recover_success_streak, - ); - pool.set_family_runtime_state( - IpFamily::V6, - family_v6_circuit.state, - init_epoch_secs, - 0, - family_v6_circuit.fail_streak, - family_v6_circuit.recover_success_streak, - ); let mut degraded_interval = true; loop { let interval = if degraded_interval { @@ -128,15 +76,8 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c }; tokio::time::sleep(interval).await; pool.prune_closed_writers().await; - reap_draining_writers( - &pool, - &mut drain_warn_next_allowed, - &mut drain_soft_evict_next_allowed, - ) - 
.await; - let now = Instant::now(); - let now_epoch_secs = MePool::now_epoch_secs(); - let v4_degraded_raw = check_family( + reap_draining_writers(&pool, &mut drain_warn_next_allowed).await; + let v4_degraded = check_family( IpFamily::V4, &pool, &rng, @@ -151,256 +92,43 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c &mut adaptive_idle_since, &mut adaptive_recover_until, &mut floor_warn_next_allowed, - &mut drain_warn_next_allowed, - &mut drain_soft_evict_next_allowed, ) .await; - let v4_degraded = apply_family_circuit_result( - &pool, - IpFamily::V4, - &mut family_v4_circuit, - Some(v4_degraded_raw), - false, - now, - now_epoch_secs, - ); - - let v6_check_ran = should_run_family_check(&mut family_v6_circuit, now); - let v6_degraded_raw = if v6_check_ran { - check_family( - IpFamily::V6, - &pool, - &rng, - &mut backoff, - &mut next_attempt, - &mut inflight, - &mut outage_backoff, - &mut outage_next_attempt, - &mut single_endpoint_outage, - &mut shadow_rotate_deadline, - &mut idle_refresh_next_attempt, - &mut adaptive_idle_since, - &mut adaptive_recover_until, - &mut floor_warn_next_allowed, - &mut drain_warn_next_allowed, - &mut drain_soft_evict_next_allowed, - ) - .await - } else { - false - }; - let v6_degraded = apply_family_circuit_result( - &pool, + let v6_degraded = check_family( IpFamily::V6, - &mut family_v6_circuit, - if v6_check_ran { - Some(v6_degraded_raw) - } else { - None - }, - true, - now, - now_epoch_secs, - ); + &pool, + &rng, + &mut backoff, + &mut next_attempt, + &mut inflight, + &mut outage_backoff, + &mut outage_next_attempt, + &mut single_endpoint_outage, + &mut shadow_rotate_deadline, + &mut idle_refresh_next_attempt, + &mut adaptive_idle_since, + &mut adaptive_recover_until, + &mut floor_warn_next_allowed, + ) + .await; degraded_interval = v4_degraded || v6_degraded; } } pub async fn me_drain_timeout_enforcer(pool: Arc) { let mut drain_warn_next_allowed: HashMap = HashMap::new(); - let mut drain_soft_evict_next_allowed: 
HashMap = HashMap::new(); loop { tokio::time::sleep(Duration::from_secs( HEALTH_DRAIN_TIMEOUT_ENFORCER_INTERVAL_SECS, )) .await; - reap_draining_writers( - &pool, - &mut drain_warn_next_allowed, - &mut drain_soft_evict_next_allowed, - ) - .await; + reap_draining_writers(&pool, &mut drain_warn_next_allowed).await; } } -fn should_run_family_check(circuit: &mut FamilyCircuitState, now: Instant) -> bool { - match circuit.state { - MeFamilyRuntimeState::Suppressed => { - if now < circuit.next_probe_at { - return false; - } - circuit.next_probe_at = - now + Duration::from_secs(FAMILY_RECOVER_PROBE_INTERVAL_SECS); - true - } - _ => true, - } -} - -fn apply_family_circuit_result( - pool: &Arc, - family: IpFamily, - circuit: &mut FamilyCircuitState, - degraded: Option, - allow_suppress: bool, - now: Instant, - now_epoch_secs: u64, -) -> bool { - let Some(degraded) = degraded else { - // Preserve suppression state when probe tick is intentionally skipped. - return false; - }; - - let previous_state = circuit.state; - match circuit.state { - MeFamilyRuntimeState::Suppressed => { - if degraded { - circuit.fail_streak = circuit.fail_streak.saturating_add(1); - circuit.recover_success_streak = 0; - let until = now + Duration::from_secs(FAMILY_SUPPRESS_WINDOW_SECS); - circuit.suppressed_until = Some(until); - circuit.state_since_at = now; - warn!( - ?family, - fail_streak = circuit.fail_streak, - suppress_secs = FAMILY_SUPPRESS_WINDOW_SECS, - "ME family remains suppressed due to ongoing failures" - ); - } else { - circuit.fail_streak = 0; - circuit.recover_success_streak = 1; - circuit.state = MeFamilyRuntimeState::Recovering; - } - } - MeFamilyRuntimeState::Recovering => { - if degraded { - circuit.fail_streak = circuit.fail_streak.saturating_add(1); - if allow_suppress { - circuit.state = MeFamilyRuntimeState::Suppressed; - let until = now + Duration::from_secs(FAMILY_SUPPRESS_WINDOW_SECS); - circuit.suppressed_until = Some(until); - circuit.next_probe_at = - now + 
Duration::from_secs(FAMILY_RECOVER_PROBE_INTERVAL_SECS); - warn!( - ?family, - fail_streak = circuit.fail_streak, - suppress_secs = FAMILY_SUPPRESS_WINDOW_SECS, - "ME family temporarily suppressed after repeated degradation" - ); - } else { - circuit.state = MeFamilyRuntimeState::Degraded; - } - } else { - circuit.recover_success_streak = circuit.recover_success_streak.saturating_add(1); - if circuit.recover_success_streak >= FAMILY_RECOVER_SUCCESS_STREAK_REQUIRED { - circuit.fail_streak = 0; - circuit.recover_success_streak = 0; - circuit.suppressed_until = None; - circuit.state = MeFamilyRuntimeState::Healthy; - info!( - ?family, - "ME family suppression lifted after stable recovery probes" - ); - } - } - } - _ => { - if degraded { - circuit.fail_streak = circuit.fail_streak.saturating_add(1); - circuit.recover_success_streak = 0; - circuit.state = MeFamilyRuntimeState::Degraded; - if allow_suppress && circuit.fail_streak >= FAMILY_SUPPRESS_FAIL_STREAK_THRESHOLD { - circuit.state = MeFamilyRuntimeState::Suppressed; - let until = now + Duration::from_secs(FAMILY_SUPPRESS_WINDOW_SECS); - circuit.suppressed_until = Some(until); - circuit.next_probe_at = - now + Duration::from_secs(FAMILY_RECOVER_PROBE_INTERVAL_SECS); - warn!( - ?family, - fail_streak = circuit.fail_streak, - suppress_secs = FAMILY_SUPPRESS_WINDOW_SECS, - "ME family temporarily suppressed after repeated degradation" - ); - } - } else { - circuit.fail_streak = 0; - circuit.recover_success_streak = 0; - circuit.suppressed_until = None; - circuit.state = MeFamilyRuntimeState::Healthy; - } - } - } - - if previous_state != circuit.state { - circuit.state_since_at = now; - } - - let suppressed_until_epoch_secs = circuit - .suppressed_until - .and_then(|until| { - if until > now { - Some( - now_epoch_secs - .saturating_add(until.saturating_duration_since(now).as_secs()), - ) - } else { - None - } - }) - .unwrap_or(0); - let state_since_epoch_secs = if previous_state == circuit.state { - 
pool.family_runtime_state_since_epoch_secs(family) - } else { - now_epoch_secs - }; - pool.set_family_runtime_state( - family, - circuit.state, - state_since_epoch_secs, - suppressed_until_epoch_secs, - circuit.fail_streak, - circuit.recover_success_streak, - ); - - !matches!(circuit.state, MeFamilyRuntimeState::Suppressed) && degraded -} - -fn draining_writer_timeout_expired( - pool: &MePool, - writer: &MeWriter, - now_epoch_secs: u64, - drain_ttl_secs: u64, -) -> bool { - if pool - .me_instadrain - .load(std::sync::atomic::Ordering::Relaxed) - { - return true; - } - - let deadline_epoch_secs = writer - .drain_deadline_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if deadline_epoch_secs != 0 { - return now_epoch_secs >= deadline_epoch_secs; - } - - if drain_ttl_secs == 0 { - return false; - } - let drain_started_at_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if drain_started_at_epoch_secs == 0 { - return false; - } - now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs -} - pub(super) async fn reap_draining_writers( pool: &Arc, warn_next_allowed: &mut HashMap, - soft_evict_next_allowed: &mut HashMap, ) { let now_epoch_secs = MePool::now_epoch_secs(); let now = Instant::now(); @@ -408,20 +136,15 @@ pub(super) async fn reap_draining_writers( let drain_threshold = pool .me_pool_drain_threshold .load(std::sync::atomic::Ordering::Relaxed); - let writers = pool.writers.read().await.clone(); let activity = pool.registry.writer_activity_snapshot().await; - let mut draining_writers = Vec::new(); + let mut draining_writers = Vec::::new(); let mut empty_writer_ids = Vec::::new(); - let mut timeout_expired_writer_ids = Vec::::new(); let mut force_close_writer_ids = Vec::::new(); - for writer in writers { + let writers = pool.writers.read().await; + for writer in writers.iter() { if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { continue; } - if 
draining_writer_timeout_expired(pool, &writer, now_epoch_secs, drain_ttl_secs) { - timeout_expired_writer_ids.push(writer.id); - continue; - } if activity .bound_clients_by_writer .get(&writer.id) @@ -432,23 +155,38 @@ pub(super) async fn reap_draining_writers( empty_writer_ids.push(writer.id); continue; } - draining_writers.push(writer); + draining_writers.push(DrainingWriterSnapshot { + id: writer.id, + writer_dc: writer.writer_dc, + addr: writer.addr, + generation: writer.generation, + created_at: writer.created_at, + draining_started_at_epoch_secs: writer + .draining_started_at_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + drain_deadline_epoch_secs: writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback: writer + .allow_drain_fallback + .load(std::sync::atomic::Ordering::Relaxed), + }); } + drop(writers); - if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize { + draining_writers.len().saturating_sub(drain_threshold as usize) + } else { + 0 + }; + + if overflow > 0 { draining_writers.sort_by(|left, right| { - let left_started = left - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - let right_started = right - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - left_started - .cmp(&right_started) + left.draining_started_at_epoch_secs + .cmp(&right.draining_started_at_epoch_secs) .then_with(|| left.created_at.cmp(&right.created_at)) .then_with(|| left.id.cmp(&right.id)) }); - let overflow = draining_writers.len().saturating_sub(drain_threshold as usize); warn!( draining_writers = draining_writers.len(), me_pool_drain_threshold = drain_threshold, @@ -460,15 +198,10 @@ pub(super) async fn reap_draining_writers( } } - let mut active_draining_writer_ids = HashSet::with_capacity(draining_writers.len()); - for writer in 
&draining_writers { - active_draining_writer_ids.insert(writer.id); - let drain_started_at_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); + for writer in draining_writers { if drain_ttl_secs > 0 - && drain_started_at_epoch_secs != 0 - && now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs + && writer.draining_started_at_epoch_secs != 0 + && now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs && should_emit_writer_warn( warn_next_allowed, writer.id, @@ -483,110 +216,22 @@ pub(super) async fn reap_draining_writers( generation = writer.generation, drain_ttl_secs, force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed), - allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed), + allow_drain_fallback = writer.allow_drain_fallback, "ME draining writer remains non-empty past drain TTL" ); } - } - - warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); - soft_evict_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); - - if pool.drain_soft_evict_enabled() && drain_ttl_secs > 0 && !draining_writers.is_empty() { - let mut force_close_ids = HashSet::::with_capacity(force_close_writer_ids.len()); - for writer_id in &force_close_writer_ids { - force_close_ids.insert(*writer_id); - } - let soft_grace_secs = pool.drain_soft_evict_grace_secs(); - let soft_trigger_age_secs = drain_ttl_secs.saturating_add(soft_grace_secs); - let per_writer_limit = pool.drain_soft_evict_per_writer(); - let soft_budget = health_drain_soft_evict_budget(pool); - let soft_cooldown = pool.drain_soft_evict_cooldown(); - let mut soft_evicted_total = 0usize; - - for writer in &draining_writers { - if soft_evicted_total >= soft_budget { - break; - } - if force_close_ids.contains(&writer.id) { - continue; - } - if pool.writer_accepts_new_binding(writer) { - 
continue; - } - let started_epoch_secs = writer - .draining_started_at_epoch_secs - .load(std::sync::atomic::Ordering::Relaxed); - if started_epoch_secs == 0 - || now_epoch_secs.saturating_sub(started_epoch_secs) < soft_trigger_age_secs - { - continue; - } - if !should_emit_writer_warn( - soft_evict_next_allowed, - writer.id, - now, - soft_cooldown, - ) { - continue; - } - - let remaining_budget = soft_budget.saturating_sub(soft_evicted_total); - let limit = per_writer_limit.min(remaining_budget); - if limit == 0 { - break; - } - let conn_ids = pool - .registry - .bound_conn_ids_for_writer_limited(writer.id, limit) - .await; - if conn_ids.is_empty() { - continue; - } - - let mut evicted_for_writer = 0usize; - for conn_id in conn_ids { - if pool.registry.evict_bound_conn_if_writer(conn_id, writer.id).await { - evicted_for_writer = evicted_for_writer.saturating_add(1); - soft_evicted_total = soft_evicted_total.saturating_add(1); - pool.stats.increment_pool_drain_soft_evict_total(); - if soft_evicted_total >= soft_budget { - break; - } - } - } - - if evicted_for_writer > 0 { - pool.stats.increment_pool_drain_soft_evict_writer_total(); - info!( - writer_id = writer.id, - writer_dc = writer.writer_dc, - endpoint = %writer.addr, - drained_connections = evicted_for_writer, - soft_budget, - soft_trigger_age_secs, - "ME draining writer soft-evicted bound clients" - ); - } + if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs + { + warn!(writer_id = writer.id, "Drain timeout, force-closing"); + force_close_writer_ids.push(writer.id); } } - let mut closed_writer_ids = HashSet::::new(); - for writer_id in timeout_expired_writer_ids { - if !closed_writer_ids.insert(writer_id) { - continue; - } - pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer_id, MeWriterTeardownReason::ReapTimeoutExpired) - .await; - pool.stats - .increment_me_draining_writers_reap_progress_total(); - } - + let close_budget 
= health_drain_close_budget(); let requested_force_close = force_close_writer_ids.len(); let requested_empty_close = empty_writer_ids.len(); let requested_close_total = requested_force_close.saturating_add(requested_empty_close); - let close_budget = health_drain_close_budget(); + let mut closed_writer_ids = HashSet::::new(); let mut closed_total = 0usize; for writer_id in force_close_writer_ids { if closed_total >= close_budget { @@ -596,10 +241,7 @@ pub(super) async fn reap_draining_writers( continue; } pool.stats.increment_pool_force_close_total(); - pool.remove_writer_and_close_clients(writer_id, MeWriterTeardownReason::ReapThresholdForce) - .await; - pool.stats - .increment_me_draining_writers_reap_progress_total(); + pool.remove_writer_and_close_clients(writer_id).await; closed_total = closed_total.saturating_add(1); } for writer_id in empty_writer_ids { @@ -609,10 +251,7 @@ pub(super) async fn reap_draining_writers( if !closed_writer_ids.insert(writer_id) { continue; } - pool.remove_writer_and_close_clients(writer_id, MeWriterTeardownReason::ReapEmpty) - .await; - pool.stats - .increment_me_draining_writers_reap_progress_total(); + pool.remove_writer_and_close_clients(writer_id).await; closed_total = closed_total.saturating_add(1); } @@ -625,6 +264,18 @@ pub(super) async fn reap_draining_writers( "ME draining close backlog deferred to next health cycle" ); } + + // Keep warn cooldown state for draining writers still present in the pool; + // drop state only once a writer is actually removed. 
+ let active_draining_writer_ids = { + let writers = pool.writers.read().await; + writers + .iter() + .filter(|writer| writer.draining.load(std::sync::atomic::Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() + }; + warn_next_allowed.retain(|writer_id, _| active_draining_writer_ids.contains(writer_id)); } pub(super) fn health_drain_close_budget() -> usize { @@ -636,17 +287,16 @@ pub(super) fn health_drain_close_budget() -> usize { .clamp(HEALTH_DRAIN_CLOSE_BUDGET_MIN, HEALTH_DRAIN_CLOSE_BUDGET_MAX) } -pub(super) fn health_drain_soft_evict_budget(pool: &MePool) -> usize { - let cpu_cores = std::thread::available_parallelism() - .map(std::num::NonZeroUsize::get) - .unwrap_or(1); - let per_core = pool.drain_soft_evict_budget_per_core(); - cpu_cores - .saturating_mul(per_core) - .clamp( - HEALTH_DRAIN_SOFT_EVICT_BUDGET_MIN, - HEALTH_DRAIN_SOFT_EVICT_BUDGET_MAX, - ) +#[derive(Debug, Clone)] +struct DrainingWriterSnapshot { + id: u64, + writer_dc: i32, + addr: SocketAddr, + generation: u64, + created_at: Instant, + draining_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, + allow_drain_fallback: bool, } fn should_emit_writer_warn( @@ -681,8 +331,6 @@ async fn check_family( adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, floor_warn_next_allowed: &mut HashMap<(i32, IpFamily), Instant>, - drain_warn_next_allowed: &mut HashMap, - drain_soft_evict_next_allowed: &mut HashMap, ) -> bool { let enabled = match family { IpFamily::V4 => pool.decision.ipv4_me, @@ -763,15 +411,8 @@ async fn check_family( floor_plan.active_writers_current, floor_plan.warm_writers_current, ); - let mut next_drain_reap_at = Instant::now(); for (dc, endpoints) in dc_endpoints { - if Instant::now() >= next_drain_reap_at { - reap_draining_writers(pool, drain_warn_next_allowed, drain_soft_evict_next_allowed) - .await; - next_drain_reap_at = Instant::now() - + 
Duration::from_secs(HEALTH_DRAIN_REAP_OPPORTUNISTIC_INTERVAL_SECS); - } if endpoints.is_empty() { continue; } @@ -915,12 +556,6 @@ async fn check_family( let mut restored = 0usize; for _ in 0..missing { - if Instant::now() >= next_drain_reap_at { - reap_draining_writers(pool, drain_warn_next_allowed, drain_soft_evict_next_allowed) - .await; - next_drain_reap_at = Instant::now() - + Duration::from_secs(HEALTH_DRAIN_REAP_OPPORTUNISTIC_INTERVAL_SECS); - } if reconnect_budget == 0 { break; } @@ -1868,34 +1503,20 @@ pub async fn me_zombie_writer_watchdog(pool: Arc) { for (writer_id, had_clients) in &zombie_ids_with_meta { let result = tokio::time::timeout( Duration::from_secs(REMOVE_TIMEOUT_SECS), - pool.remove_writer_and_close_clients( - *writer_id, - MeWriterTeardownReason::WatchdogStuckDraining, - ), + pool.remove_writer_and_close_clients(*writer_id), ) .await; match result { - Ok(true) => { + Ok(()) => { removal_timeout_streak.remove(writer_id); pool.stats.increment_pool_force_close_total(); - pool.stats - .increment_me_draining_writers_reap_progress_total(); info!( writer_id, had_clients, "Zombie writer removed by watchdog" ); } - Ok(false) => { - removal_timeout_streak.remove(writer_id); - debug!( - writer_id, - had_clients, - "Zombie writer watchdog removal became no-op" - ); - } Err(_) => { - pool.stats.increment_me_writer_teardown_timeout_total(); let streak = removal_timeout_streak .entry(*writer_id) .and_modify(|value| *value = value.saturating_add(1)) @@ -1909,22 +1530,16 @@ pub async fn me_zombie_writer_watchdog(pool: Arc) { if *streak < HARD_DETACH_TIMEOUT_STREAK { continue; } - pool.stats.increment_me_writer_teardown_escalation_total(); let hard_detach = tokio::time::timeout( Duration::from_secs(REMOVE_TIMEOUT_SECS), - pool.remove_draining_writer_hard_detach( - *writer_id, - MeWriterTeardownReason::WatchdogStuckDraining, - ), + pool.remove_draining_writer_hard_detach(*writer_id), ) .await; match hard_detach { Ok(true) => { 
removal_timeout_streak.remove(writer_id); pool.stats.increment_pool_force_close_total(); - pool.stats - .increment_me_draining_writers_reap_progress_total(); info!( writer_id, had_clients, @@ -1940,7 +1555,6 @@ pub async fn me_zombie_writer_watchdog(pool: Arc) { ); } Err(_) => { - pool.stats.increment_me_writer_teardown_timeout_total(); warn!( writer_id, had_clients, @@ -1964,19 +1578,13 @@ mod tests { use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; - use super::{ - FamilyCircuitState, apply_family_circuit_result, reap_draining_writers, - should_run_family_check, - }; + use super::reap_draining_writers; use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; use crate::crypto::SecureRandom; - use crate::network::IpFamily; use crate::network::probe::NetworkDecision; use crate::stats::Stats; use crate::transport::middle_proxy::codec::WriterCommand; - use crate::transport::middle_proxy::pool::{ - MeFamilyRuntimeState, MePool, MeWriter, WriterContour, - }; + use crate::transport::middle_proxy::pool::{MePool, MeWriter, WriterContour}; use crate::transport::middle_proxy::registry::ConnMeta; async fn make_pool(me_pool_drain_threshold: u64) -> Arc { @@ -1984,6 +1592,15 @@ mod tests { me_pool_drain_threshold, ..GeneralConfig::default() }; + let mut proxy_map_v4 = HashMap::new(); + proxy_map_v4.insert( + 2, + vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)], + ); + let decision = NetworkDecision { + ipv4_me: true, + ..NetworkDecision::default() + }; MePool::new( None, vec![1u8; 32], @@ -1995,10 +1612,10 @@ mod tests { None, 12, 1200, - HashMap::new(), + proxy_map_v4, HashMap::new(), None, - NetworkDecision::default(), + decision, None, Arc::new(SecureRandom::new()), Arc::new(Stats::default()), @@ -2066,8 +1683,6 @@ mod tests { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, - general.me_route_hybrid_max_wait_ms, - general.me_route_blocking_send_timeout_ms, 
general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ) @@ -2117,19 +1732,66 @@ mod tests { conn_id } + async fn insert_live_writer(pool: &Arc, writer_id: u64, writer_dc: i32) { + let (tx, _writer_rx) = mpsc::channel::(8); + let writer = MeWriter { + id: writer_id, + addr: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))), + 4000 + writer_id as u16, + ), + source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), + writer_dc, + generation: 2, + contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())), + created_at: Instant::now(), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(false)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); + } + #[tokio::test] async fn reap_draining_writers_force_closes_oldest_over_threshold() { + let pool = make_pool(2).await; + insert_live_writer(&pool, 1, 2).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; + let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; + let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; + let mut warn_next_allowed = HashMap::new(); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); + assert_eq!(writer_ids, vec![1, 20, 30]); + assert!(pool.registry.get_writer(conn_a).await.is_none()); + 
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); + assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); + } + + #[tokio::test] + async fn reap_draining_writers_force_closes_overflow_without_replacement() { let pool = make_pool(2).await; let now_epoch_secs = MePool::now_epoch_secs(); let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await; let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; - let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + let mut writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); + writer_ids.sort_unstable(); assert_eq!(writer_ids, vec![20, 30]); assert!(pool.registry.get_writer(conn_a).await.is_none()); assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); @@ -2144,9 +1806,8 @@ mod tests { let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await; let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; let writer_ids: Vec = pool.writers.read().await.iter().map(|writer| writer.id).collect(); assert_eq!(writer_ids, vec![10, 20, 30]); @@ -2154,47 +1815,4 @@ mod tests { assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20); 
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30); } - - #[tokio::test] - async fn suppressed_family_probe_skip_preserves_suppressed_state() { - let pool = make_pool(0).await; - let now = Instant::now(); - let now_epoch_secs = MePool::now_epoch_secs(); - let suppressed_until_epoch_secs = now_epoch_secs.saturating_add(60); - pool.set_family_runtime_state( - IpFamily::V6, - MeFamilyRuntimeState::Suppressed, - now_epoch_secs, - suppressed_until_epoch_secs, - 7, - 0, - ); - - let mut circuit = FamilyCircuitState { - state: MeFamilyRuntimeState::Suppressed, - state_since_at: now, - suppressed_until: Some(now + Duration::from_secs(60)), - next_probe_at: now + Duration::from_secs(5), - fail_streak: 7, - recover_success_streak: 0, - }; - - assert!(!should_run_family_check(&mut circuit, now)); - assert!(!apply_family_circuit_result( - &pool, - IpFamily::V6, - &mut circuit, - None, - true, - now, - now_epoch_secs, - )); - assert_eq!(circuit.state, MeFamilyRuntimeState::Suppressed); - assert_eq!(circuit.fail_streak, 7); - assert_eq!(circuit.recover_success_streak, 0); - assert_eq!( - pool.family_runtime_state(IpFamily::V6), - MeFamilyRuntimeState::Suppressed, - ); - } } diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 8c57717..74330b8 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -22,11 +22,23 @@ mod selftest; mod wire; mod pool_status; #[cfg(test)] +#[path = "tests/health_regression_tests.rs"] mod health_regression_tests; #[cfg(test)] +#[path = "tests/health_integration_tests.rs"] mod health_integration_tests; #[cfg(test)] +#[path = "tests/health_adversarial_tests.rs"] mod health_adversarial_tests; +#[cfg(test)] +#[path = "tests/send_adversarial_tests.rs"] +mod send_adversarial_tests; +#[cfg(test)] +#[path = "tests/pool_writer_security_tests.rs"] +mod pool_writer_security_tests; +#[cfg(test)] +#[path = "tests/pool_refill_security_tests.rs"] +mod 
pool_refill_security_tests; use bytes::Bytes; diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 27bcb07..14c7d52 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -220,6 +220,7 @@ pub struct MePool { pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, + pub(super) draining_active_runtime: AtomicU64, pub(super) stats: Arc, pub(super) generation: AtomicU64, pub(super) active_generation: AtomicU64, @@ -254,8 +255,6 @@ pub struct MePool { pub(super) me_reader_route_data_wait_ms: Arc, pub(super) me_route_no_writer_mode: AtomicU8, pub(super) me_route_no_writer_wait: Duration, - pub(super) me_route_hybrid_max_wait: Duration, - pub(super) me_route_blocking_send_timeout: Duration, pub(super) me_route_inline_recovery_attempts: u32, pub(super) me_route_inline_recovery_wait: Duration, pub(super) me_health_interval_ms_unhealthy: AtomicU64, @@ -393,8 +392,6 @@ impl MePool { me_warn_rate_limit_ms: u64, me_route_no_writer_mode: MeRouteNoWriterMode, me_route_no_writer_wait_ms: u64, - me_route_hybrid_max_wait_ms: u64, - me_route_blocking_send_timeout_ms: u64, me_route_inline_recovery_attempts: u32, me_route_inline_recovery_wait_ms: u64, ) -> Arc { @@ -536,6 +533,7 @@ impl MePool { refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), + draining_active_runtime: AtomicU64::new(0), generation: AtomicU64::new(1), active_generation: AtomicU64::new(1), warm_generation: AtomicU64::new(0), @@ -581,10 +579,6 @@ impl MePool { me_reader_route_data_wait_ms: Arc::new(AtomicU64::new(me_reader_route_data_wait_ms)), me_route_no_writer_mode: AtomicU8::new(me_route_no_writer_mode.as_u8()), me_route_no_writer_wait: Duration::from_millis(me_route_no_writer_wait_ms), - me_route_hybrid_max_wait: Duration::from_millis(me_route_hybrid_max_wait_ms), - 
me_route_blocking_send_timeout: Duration::from_millis( - me_route_blocking_send_timeout_ms, - ), me_route_inline_recovery_attempts, me_route_inline_recovery_wait: Duration::from_millis(me_route_inline_recovery_wait_ms), me_health_interval_ms_unhealthy: AtomicU64::new(me_health_interval_ms_unhealthy.max(1)), @@ -1015,6 +1009,33 @@ impl MePool { ) } + #[allow(dead_code)] + pub(super) fn draining_active_runtime(&self) -> u64 { + self.draining_active_runtime.load(Ordering::Relaxed) + } + + pub(super) fn increment_draining_active_runtime(&self) { + self.draining_active_runtime.fetch_add(1, Ordering::Relaxed); + } + + pub(super) fn decrement_draining_active_runtime(&self) { + let mut current = self.draining_active_runtime.load(Ordering::Relaxed); + loop { + if current == 0 { + break; + } + match self.draining_active_runtime.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + } + pub(super) async fn key_selector(&self) -> u32 { self.proxy_secret.read().await.key_selector } diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index 29a70c5..5edbb37 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; -use rand::Rng; +use rand::RngExt; use rand::seq::SliceRandom; use tracing::{debug, info, warn}; diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index 3c5d4b3..7dff7b5 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -71,16 +71,31 @@ impl MePool { } if let Some((addr, expiry)) = earliest_quarantine { + let remaining = expiry.saturating_duration_since(now); + if remaining.is_zero() { + return vec![addr]; + } + drop(guard); debug!( %addr, wait_ms = 
expiry.saturating_duration_since(now).as_millis(), "All ME endpoints are quarantined for the DC group; waiting for quarantine expiry" ); + tokio::time::sleep(remaining).await; + return vec![addr]; } Vec::new() } + #[cfg(test)] + pub(super) async fn connectable_endpoints_for_test( + &self, + endpoints: &[SocketAddr], + ) -> Vec { + self.connectable_endpoints(endpoints).await + } + pub(super) async fn has_refill_inflight_for_dc_key(&self, key: RefillDcKey) -> bool { let guard = self.refill_inflight_dc.lock().await; guard.contains(&key) diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index bfd56c6..23dca03 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use std::sync::atomic::Ordering; use std::time::Duration; -use rand::Rng; +use rand::RngExt; use rand::seq::SliceRandom; use tracing::{debug, info, warn}; use std::collections::hash_map::DefaultHasher; @@ -147,6 +147,38 @@ impl MePool { out } + pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool { + let desired_by_dc = self.desired_dc_endpoints().await; + let required_dcs: HashSet = desired_by_dc + .iter() + .filter_map(|(dc, endpoints)| { + if endpoints.is_empty() { + None + } else { + Some(*dc) + } + }) + .collect(); + if required_dcs.is_empty() { + return true; + } + + let ws = self.writers.read().await; + let mut covered_dcs = HashSet::::with_capacity(required_dcs.len()); + for writer in ws.iter() { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + if required_dcs.contains(&writer.writer_dc) { + covered_dcs.insert(writer.writer_dc); + if covered_dcs.len() == required_dcs.len() { + return true; + } + } + } + false + } + fn hardswap_warmup_connect_delay_ms(&self) -> u64 { let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed); let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed); @@ -506,12 
+538,30 @@ impl MePool { coverage_ratio = format_args!("{coverage_ratio:.3}"), min_ratio = format_args!("{min_ratio:.3}"), drain_timeout_secs, - "ME reinit cycle covered; draining stale writers" + "ME reinit cycle covered; processing stale writers" ); self.stats.increment_pool_swap_total(); + let can_drop_with_replacement = self + .has_non_draining_writer_per_desired_dc_group() + .await; + if can_drop_with_replacement { + info!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind" + ); + } else { + warn!( + stale_writers = stale_writer_ids.len(), + "ME reinit stale writers: replacement coverage incomplete, keeping draining fallback" + ); + } for writer_id in stale_writer_ids { self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap) .await; + if can_drop_with_replacement { + self.stats.increment_pool_force_close_total(); + self.remove_writer_and_close_clients(writer_id).await; + } } if hardswap { self.clear_pending_hardswap_state(); diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 5fe45cb..99070a8 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -40,7 +40,6 @@ pub(crate) struct MeApiDcStatusSnapshot { pub floor_max: usize, pub floor_capped: bool, pub alive_writers: usize, - pub coverage_ratio: f64, pub coverage_pct: f64, pub fresh_alive_writers: usize, pub fresh_coverage_pct: f64, @@ -63,7 +62,6 @@ pub(crate) struct MeApiStatusSnapshot { pub available_pct: f64, pub required_writers: usize, pub alive_writers: usize, - pub coverage_ratio: f64, pub coverage_pct: f64, pub fresh_alive_writers: usize, pub fresh_coverage_pct: f64, @@ -126,12 +124,6 @@ pub(crate) struct MeApiRuntimeSnapshot { pub me_reconnect_backoff_cap_ms: u64, pub me_reconnect_fast_retry_count: u32, pub me_pool_drain_ttl_secs: u64, - pub me_instadrain: bool, - pub 
me_pool_drain_soft_evict_enabled: bool, - pub me_pool_drain_soft_evict_grace_secs: u64, - pub me_pool_drain_soft_evict_per_writer: u8, - pub me_pool_drain_soft_evict_budget_per_core: u16, - pub me_pool_drain_soft_evict_cooldown_ms: u64, pub me_pool_force_close_secs: u64, pub me_pool_min_fresh_ratio: f32, pub me_bind_stale_mode: &'static str, @@ -345,8 +337,6 @@ impl MePool { let mut available_endpoints = 0usize; let mut alive_writers = 0usize; let mut fresh_alive_writers = 0usize; - let mut coverage_ratio_dcs_total = 0usize; - let mut coverage_ratio_dcs_covered = 0usize; let floor_mode = self.floor_mode(); let adaptive_cpu_cores = (self .me_adaptive_floor_cpu_cores_effective @@ -398,12 +388,6 @@ impl MePool { available_endpoints += dc_available_endpoints; alive_writers += dc_alive_writers; fresh_alive_writers += dc_fresh_alive_writers; - if endpoint_count > 0 { - coverage_ratio_dcs_total += 1; - if dc_alive_writers > 0 { - coverage_ratio_dcs_covered += 1; - } - } dcs.push(MeApiDcStatusSnapshot { dc, @@ -426,11 +410,6 @@ impl MePool { floor_max, floor_capped, alive_writers: dc_alive_writers, - coverage_ratio: if endpoint_count > 0 && dc_alive_writers > 0 { - 100.0 - } else { - 0.0 - }, coverage_pct: ratio_pct(dc_alive_writers, dc_required_writers), fresh_alive_writers: dc_fresh_alive_writers, fresh_coverage_pct: ratio_pct(dc_fresh_alive_writers, dc_required_writers), @@ -447,7 +426,6 @@ impl MePool { available_pct: ratio_pct(available_endpoints, configured_endpoints), required_writers, alive_writers, - coverage_ratio: ratio_pct(coverage_ratio_dcs_covered, coverage_ratio_dcs_total), coverage_pct: ratio_pct(alive_writers, required_writers), fresh_alive_writers, fresh_coverage_pct: ratio_pct(fresh_alive_writers, required_writers), @@ -584,23 +562,6 @@ impl MePool { me_reconnect_backoff_cap_ms: self.me_reconnect_backoff_cap.as_millis() as u64, me_reconnect_fast_retry_count: self.me_reconnect_fast_retry_count, me_pool_drain_ttl_secs: 
self.me_pool_drain_ttl_secs.load(Ordering::Relaxed), - me_instadrain: self.me_instadrain.load(Ordering::Relaxed), - me_pool_drain_soft_evict_enabled: self - .me_pool_drain_soft_evict_enabled - .load(Ordering::Relaxed), - me_pool_drain_soft_evict_grace_secs: self - .me_pool_drain_soft_evict_grace_secs - .load(Ordering::Relaxed), - me_pool_drain_soft_evict_per_writer: self - .me_pool_drain_soft_evict_per_writer - .load(Ordering::Relaxed), - me_pool_drain_soft_evict_budget_per_core: self - .me_pool_drain_soft_evict_budget_per_core - .load(Ordering::Relaxed) - .min(u16::MAX as u32) as u16, - me_pool_drain_soft_evict_cooldown_ms: self - .me_pool_drain_soft_evict_cooldown_ms - .load(Ordering::Relaxed), me_pool_force_close_secs: self.me_pool_force_close_secs.load(Ordering::Relaxed), me_pool_min_fresh_ratio: Self::permille_to_ratio( self.me_pool_min_fresh_ratio_permille.load(Ordering::Relaxed), diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index b0ba776..39f7121 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -6,9 +6,8 @@ use std::io::ErrorKind; use bytes::Bytes; use bytes::BytesMut; -use rand::Rng; +use rand::RngExt; use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; @@ -16,9 +15,6 @@ use crate::config::MeBindStaleMode; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::{RPC_CLOSE_EXT_U32, RPC_PING_U32}; -use crate::stats::{ - MeWriterCleanupSideEffectStep, MeWriterTeardownMode, MeWriterTeardownReason, -}; use super::codec::{RpcWriter, WriterCommand}; use super::pool::{MePool, MeWriter, WriterContour}; @@ -31,7 +27,7 @@ const ME_IDLE_KEEPALIVE_MAX_SECS: u64 = 5; const ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS: u64 = 700; #[derive(Clone, Copy)] -enum WriterRemoveGuardMode { +enum WriterTeardownMode { Any, DrainingOnly, } @@ 
-51,18 +47,7 @@ impl MePool { } for writer_id in closed_writer_ids { - if self.registry.is_writer_empty(writer_id).await { - let _ = self - .remove_writer_only(writer_id, MeWriterTeardownReason::PruneClosedWriter) - .await; - } else { - let _ = self - .remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::PruneClosedWriter, - ) - .await; - } + let _ = self.remove_writer_and_close_clients(writer_id).await; } } @@ -183,11 +168,7 @@ impl MePool { .is_ok() { if let Some(pool) = pool_writer_task.upgrade() { - pool.remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::WriterTaskExit, - ) - .await; + pool.remove_writer_and_close_clients(writer_id).await; } else { cancel_wr.cancel(); } @@ -278,11 +259,7 @@ impl MePool { .is_ok() { if let Some(pool) = pool.upgrade() { - pool.remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::ReaderExit, - ) - .await; + pool.remove_writer_and_close_clients(writer_id).await; } else { // Fallback for shutdown races: make writer task exit quickly so stale // channels are observable by periodic prune. 
@@ -352,28 +329,41 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes()); - let now_epoch_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as u64; - let mut run_cleanup = false; - if let Some(pool) = pool_ping.upgrade() { - let last_cleanup_ms = pool - .ping_tracker_last_cleanup_epoch_ms - .load(Ordering::Relaxed); - if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000 - && pool + { + let mut tracker = ping_tracker_ping.lock().await; + let now_epoch_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64; + let mut run_cleanup = false; + if let Some(pool) = pool_ping.upgrade() { + let last_cleanup_ms = pool .ping_tracker_last_cleanup_epoch_ms - .compare_exchange( - last_cleanup_ms, - now_epoch_ms, - Ordering::AcqRel, - Ordering::Relaxed, - ) - .is_ok() - { - run_cleanup = true; + .load(Ordering::Relaxed); + if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000 + && pool + .ping_tracker_last_cleanup_epoch_ms + .compare_exchange( + last_cleanup_ms, + now_epoch_ms, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_ok() + { + run_cleanup = true; + } } + + if run_cleanup { + let before = tracker.len(); + tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_ping.increment_me_keepalive_timeout_by(expired as u64); + } + } + tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); stats_ping.increment_me_keepalive_sent(); @@ -385,29 +375,16 @@ impl MePool { stats_ping.increment_me_keepalive_failed(); debug!("ME ping failed, removing dead writer"); cancel_ping.cancel(); - if let Some(pool) = pool_ping.upgrade() - && cleanup_for_ping - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) - 
.is_ok() + if cleanup_for_ping + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() { - pool.remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::PingSendFail, - ) - .await; + if let Some(pool) = pool_ping.upgrade() { + pool.remove_writer_and_close_clients(writer_id).await; + } } break; } - let mut tracker = ping_tracker_ping.lock().await; - if run_cleanup { - let before = tracker.len(); - tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); - let expired = before.saturating_sub(tracker.len()); - if expired > 0 { - stats_ping.increment_me_keepalive_timeout_by(expired as u64); - } - } - tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } }); @@ -487,11 +464,7 @@ impl MePool { .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { - pool.remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::SignalSendFail, - ) - .await; + pool.remove_writer_and_close_clients(writer_id).await; } break; } @@ -525,11 +498,7 @@ impl MePool { .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed) .is_ok() { - pool.remove_writer_and_close_clients( - writer_id, - MeWriterTeardownReason::SignalSendFail, - ) - .await; + pool.remove_writer_and_close_clients(writer_id).await; } break; } @@ -542,48 +511,25 @@ impl MePool { Ok(()) } - pub(crate) async fn remove_writer_and_close_clients( - self: &Arc, - writer_id: u64, - reason: MeWriterTeardownReason, - ) -> bool { + pub(crate) async fn remove_writer_and_close_clients(self: &Arc, writer_id: u64) { // Full client cleanup now happens inside `registry.writer_lost` to keep // writer reap/remove paths strictly non-blocking per connection. 
- self.remove_writer_with_mode( - writer_id, - reason, - MeWriterTeardownMode::Normal, - WriterRemoveGuardMode::Any, - ) - .await + let _ = self + .remove_writer_with_mode(writer_id, WriterTeardownMode::Any) + .await; } pub(super) async fn remove_draining_writer_hard_detach( self: &Arc, writer_id: u64, - reason: MeWriterTeardownReason, ) -> bool { - self.remove_writer_with_mode( - writer_id, - reason, - MeWriterTeardownMode::HardDetach, - WriterRemoveGuardMode::DrainingOnly, - ) - .await + self.remove_writer_with_mode(writer_id, WriterTeardownMode::DrainingOnly) + .await } - async fn remove_writer_only( - self: &Arc, - writer_id: u64, - reason: MeWriterTeardownReason, - ) -> bool { - self.remove_writer_with_mode( - writer_id, - reason, - MeWriterTeardownMode::Normal, - WriterRemoveGuardMode::Any, - ) - .await + async fn remove_writer_only(self: &Arc, writer_id: u64) -> bool { + self.remove_writer_with_mode(writer_id, WriterTeardownMode::Any) + .await } // Authoritative teardown primitive shared by normal cleanup and watchdog path. 
@@ -595,13 +541,8 @@ impl MePool { async fn remove_writer_with_mode( self: &Arc, writer_id: u64, - reason: MeWriterTeardownReason, - mode: MeWriterTeardownMode, - guard_mode: WriterRemoveGuardMode, + mode: WriterTeardownMode, ) -> bool { - let started_at = Instant::now(); - self.stats - .increment_me_writer_teardown_attempt_total(reason, mode); let mut close_tx: Option> = None; let mut removed_addr: Option = None; let mut removed_dc: Option = None; @@ -611,18 +552,16 @@ impl MePool { { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { - if matches!(guard_mode, WriterRemoveGuardMode::DrainingOnly) + if matches!(mode, WriterTeardownMode::DrainingOnly) && !ws[pos].draining.load(Ordering::Relaxed) { - self.stats.increment_me_writer_teardown_noop_total(); - self.stats - .observe_me_writer_teardown_duration(mode, started_at.elapsed()); return false; } let w = ws.remove(pos); let was_draining = w.draining.load(Ordering::Relaxed); if was_draining { self.stats.decrement_pool_drain_active(); + self.decrement_draining_active_runtime(); } self.stats.increment_me_writer_removed_total(); w.cancel.cancel(); @@ -650,34 +589,11 @@ impl MePool { } self.rtt_stats.lock().await.remove(&writer_id); if let Some(tx) = close_tx { - match tx.try_send(WriterCommand::Close) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - self.stats.increment_me_writer_close_signal_drop_total(); - self.stats - .increment_me_writer_close_signal_channel_full_total(); - self.stats.increment_me_writer_cleanup_side_effect_failures_total( - MeWriterCleanupSideEffectStep::CloseSignalChannelFull, - ); - debug!( - writer_id, - "Skipping close signal for removed writer: command channel is full" - ); - } - Err(TrySendError::Closed(_)) => { - self.stats.increment_me_writer_close_signal_drop_total(); - self.stats.increment_me_writer_cleanup_side_effect_failures_total( - MeWriterCleanupSideEffectStep::CloseSignalChannelClosed, - ); - debug!( - writer_id, - "Skipping 
close signal for removed writer: command channel is closed" - ); - } - } + let _ = tx.send(WriterCommand::Close).await; } if let Some(addr) = removed_addr { if let Some(uptime) = removed_uptime { + // Quarantine flapping endpoints regardless of draining state. self.maybe_quarantine_flapping_endpoint(addr, uptime).await; } if trigger_refill @@ -686,13 +602,6 @@ impl MePool { self.trigger_immediate_refill_for_dc(addr, writer_dc); } } - if removed { - self.stats.increment_me_writer_teardown_success_total(mode); - } else { - self.stats.increment_me_writer_teardown_noop_total(); - } - self.stats - .observe_me_writer_teardown_duration(mode, started_at.elapsed()); removed } @@ -719,6 +628,7 @@ impl MePool { .store(drain_deadline_epoch_secs, Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); + self.increment_draining_active_runtime(); } w.contour .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 2ee55c1..a22b98d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -169,7 +169,6 @@ impl ConnRegistry { None } - #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let tx = { let inner = self.inner.read().await; @@ -395,89 +394,31 @@ impl ConnRegistry { inner.writer_for_conn.keys().copied().collect() } - pub(super) async fn bound_conn_ids_for_writer_limited( - &self, - writer_id: u64, - limit: usize, - ) -> Vec { - if limit == 0 { - return Vec::new(); - } - let inner = self.inner.read().await; - let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { - return Vec::new(); - }; - let mut out = conn_ids.iter().copied().collect::>(); - out.sort_unstable(); - out.truncate(limit); - out - } - - pub(super) async fn evict_bound_conn_if_writer(&self, conn_id: u64, writer_id: u64) -> bool { - let maybe_client_tx = { - let mut inner = 
self.inner.write().await; - if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { - return false; - } - - let client_tx = inner.map.get(&conn_id).cloned(); - inner.map.remove(&conn_id); - inner.meta.remove(&conn_id); - inner.writer_for_conn.remove(&conn_id); - - let became_empty = if let Some(set) = inner.conns_for_writer.get_mut(&writer_id) { - set.remove(&conn_id); - set.is_empty() - } else { - false - }; - if became_empty { - inner - .writer_idle_since_epoch_secs - .insert(writer_id, Self::now_epoch_secs()); - } - client_tx - }; - - if let Some(client_tx) = maybe_client_tx { - let _ = client_tx.try_send(MeResponse::Close); - } - true - } - pub async fn writer_lost(&self, writer_id: u64) -> Vec { - let mut close_txs = Vec::>::new(); - let mut out = Vec::new(); - { - let mut inner = self.inner.write().await; - inner.writers.remove(&writer_id); - inner.last_meta_for_writer.remove(&writer_id); - inner.writer_idle_since_epoch_secs.remove(&writer_id); - let conns = inner - .conns_for_writer - .remove(&writer_id) - .unwrap_or_default() - .into_iter() - .collect::>(); + let mut inner = self.inner.write().await; + inner.writers.remove(&writer_id); + inner.last_meta_for_writer.remove(&writer_id); + inner.writer_idle_since_epoch_secs.remove(&writer_id); + let conns = inner + .conns_for_writer + .remove(&writer_id) + .unwrap_or_default() + .into_iter() + .collect::>(); - for conn_id in conns { - if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { - continue; - } - inner.writer_for_conn.remove(&conn_id); - if let Some(client_tx) = inner.map.remove(&conn_id) { - close_txs.push(client_tx); - } - if let Some(meta) = inner.meta.remove(&conn_id) { - out.push(BoundConn { conn_id, meta }); - } + let mut out = Vec::new(); + for conn_id in conns { + if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { + continue; + } + inner.writer_for_conn.remove(&conn_id); + if let Some(m) = inner.meta.get(&conn_id) { + out.push(BoundConn { + conn_id, 
+ meta: m.clone(), + }); } } - - for client_tx in close_txs { - let _ = client_tx.try_send(MeResponse::Close); - } - out } @@ -495,16 +436,45 @@ impl ConnRegistry { .map(|s| s.is_empty()) .unwrap_or(true) } + + pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool { + let mut inner = self.inner.write().await; + let Some(conn_ids) = inner.conns_for_writer.get(&writer_id) else { + // Writer is already absent from the registry. + return true; + }; + if !conn_ids.is_empty() { + return false; + } + + inner.writers.remove(&writer_id); + inner.last_meta_for_writer.remove(&writer_id); + inner.writer_idle_since_epoch_secs.remove(&writer_id); + inner.conns_for_writer.remove(&writer_id); + true + } + + #[allow(dead_code)] + pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet { + let inner = self.inner.read().await; + let mut out = HashSet::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(conns) = inner.conns_for_writer.get(writer_id) + && !conns.is_empty() + { + out.insert(*writer_id); + } + } + out + } } #[cfg(test)] mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; - use std::time::Duration; use super::ConnMeta; use super::ConnRegistry; - use super::MeResponse; #[tokio::test] async fn writer_activity_snapshot_tracks_writer_and_dc_load() { @@ -673,39 +643,6 @@ mod tests { assert!(registry.is_writer_empty(20).await); } - #[tokio::test] - async fn writer_lost_removes_bound_conn_from_registry_and_signals_close() { - let registry = ConnRegistry::new(); - let (conn_id, mut rx) = registry.register().await; - let (writer_tx, _writer_rx) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx).await; - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - - assert!( - registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - - let lost = registry.writer_lost(10).await; - 
assert_eq!(lost.len(), 1); - assert_eq!(lost[0].conn_id, conn_id); - assert!(registry.get_writer(conn_id).await.is_none()); - assert!(registry.get_meta(conn_id).await.is_none()); - assert_eq!(registry.unregister(conn_id).await, None); - let close = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await; - assert!(matches!(close, Ok(Some(MeResponse::Close)))); - } - #[tokio::test] async fn bind_writer_rejects_unregistered_writer() { let registry = ConnRegistry::new(); @@ -730,47 +667,15 @@ mod tests { } #[tokio::test] - async fn bound_conn_ids_for_writer_limited_is_sorted_and_bounded() { + async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() { let registry = ConnRegistry::new(); - let (writer_tx, _writer_rx) = tokio::sync::mpsc::channel(8); - registry.register_writer(10, writer_tx).await; - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); - let mut conn_ids = Vec::new(); - for _ in 0..5 { - let (conn_id, _rx) = registry.register().await; - assert!( - registry - .bind_writer( - conn_id, - 10, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 0, - }, - ) - .await - ); - conn_ids.push(conn_id); - } - conn_ids.sort_unstable(); - - let limited = registry.bound_conn_ids_for_writer_limited(10, 3).await; - assert_eq!(limited.len(), 3); - assert_eq!(limited, conn_ids.into_iter().take(3).collect::>()); - } - - #[tokio::test] - async fn evict_bound_conn_if_writer_does_not_touch_rebound_conn() { - let registry = ConnRegistry::new(); - let (conn_id, mut rx) = registry.register().await; + let (conn_id, _rx) = registry.register().await; let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8); let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8); registry.register_writer(10, writer_tx_a).await; registry.register_writer(20, writer_tx_b).await; - let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); assert!( 
registry .bind_writer( @@ -785,29 +690,10 @@ mod tests { ) .await ); - assert!( - registry - .bind_writer( - conn_id, - 20, - ConnMeta { - target_dc: 2, - client_addr: addr, - our_addr: addr, - proto_flags: 1, - }, - ) - .await - ); - let evicted = registry.evict_bound_conn_if_writer(conn_id, 10).await; - assert!(!evicted); - assert_eq!(registry.get_writer(conn_id).await.expect("writer").writer_id, 20); - assert!(rx.try_recv().is_err()); - - let evicted = registry.evict_bound_conn_if_writer(conn_id, 20).await; - assert!(evicted); - assert!(registry.get_writer(conn_id).await.is_none()); - assert!(matches!(rx.try_recv(), Ok(MeResponse::Close))); + let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await; + assert!(non_empty.contains(&10)); + assert!(!non_empty.contains(&20)); + assert!(!non_empty.contains(&30)); } } diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 82118d8..5e0e562 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -6,7 +6,6 @@ use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; use bytes::Bytes; -use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, warn}; @@ -14,7 +13,6 @@ use crate::config::{MeRouteNoWriterMode, MeWriterPickMode}; use crate::error::{ProxyError, Result}; use crate::network::IpFamily; use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32}; -use crate::stats::MeWriterTeardownReason; use super::MePool; use super::codec::WriterCommand; @@ -31,29 +29,6 @@ const PICK_PENALTY_DRAINING: u64 = 600; const PICK_PENALTY_STALE: u64 = 300; const PICK_PENALTY_DEGRADED: u64 = 250; -enum TimedSendError { - Closed(T), - Timeout(T), -} - -async fn send_writer_command_with_timeout( - tx: &mpsc::Sender, - cmd: WriterCommand, - timeout: Duration, -) -> std::result::Result<(), TimedSendError> { - if timeout.is_zero() { - return tx.send(cmd).await.map_err(|err| TimedSendError::Closed(err.0)); - } - 
match tokio::time::timeout(timeout, tx.reserve()).await { - Ok(Ok(permit)) => { - permit.send(cmd); - Ok(()) - } - Ok(Err(_)) => Err(TimedSendError::Closed(cmd)), - Err(_) => Err(TimedSendError::Timeout(cmd)), - } -} - impl MePool { /// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default. pub async fn send_proxy_req( @@ -103,18 +78,8 @@ impl MePool { let mut hybrid_last_recovery_at: Option = None; let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50)); let mut hybrid_wait_current = hybrid_wait_step; - let hybrid_deadline = Instant::now() + self.me_route_hybrid_max_wait; loop { - if matches!(no_writer_mode, MeRouteNoWriterMode::HybridAsyncPersistent) - && Instant::now() >= hybrid_deadline - { - self.stats.increment_me_no_writer_failfast_total(); - return Err(ProxyError::Proxy( - "No ME writer available in hybrid wait window".into(), - )); - } - let mut skip_writer_id: Option = None; let current_meta = self .registry .get_meta(conn_id) @@ -125,42 +90,16 @@ impl MePool { match current.tx.try_send(WriterCommand::Data(current_payload.clone())) { Ok(()) => return Ok(()), Err(TrySendError::Full(cmd)) => { - match send_writer_command_with_timeout( - ¤t.tx, - cmd, - self.me_route_blocking_send_timeout, - ) - .await - { - Ok(()) => return Ok(()), - Err(TimedSendError::Closed(_)) => { - warn!(writer_id = current.writer_id, "ME writer channel closed"); - self.remove_writer_and_close_clients( - current.writer_id, - MeWriterTeardownReason::RouteChannelClosed, - ) - .await; - continue; - } - Err(TimedSendError::Timeout(_)) => { - debug!( - conn_id, - writer_id = current.writer_id, - timeout_ms = self.me_route_blocking_send_timeout.as_millis() - as u64, - "ME writer send timed out for bound writer, trying reroute" - ); - skip_writer_id = Some(current.writer_id); - } + if current.tx.send(cmd).await.is_ok() { + return Ok(()); } + warn!(writer_id = current.writer_id, "ME writer channel closed"); + 
self.remove_writer_and_close_clients(current.writer_id).await; + continue; } Err(TrySendError::Closed(_)) => { warn!(writer_id = current.writer_id, "ME writer channel closed"); - self.remove_writer_and_close_clients( - current.writer_id, - MeWriterTeardownReason::RouteChannelClosed, - ) - .await; + self.remove_writer_and_close_clients(current.writer_id).await; continue; } } @@ -261,9 +200,6 @@ impl MePool { .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) .await; } - if let Some(skip_writer_id) = skip_writer_id { - candidate_indices.retain(|idx| writers_snapshot[*idx].id != skip_writer_id); - } if candidate_indices.is_empty() { let pick_mode = self.writer_pick_mode(); match no_writer_mode { @@ -436,17 +372,20 @@ impl MePool { } let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); - match w.tx.try_send(WriterCommand::Data(payload.clone())) { - Ok(()) => { - self.stats.increment_me_writer_pick_success_try_total(pick_mode); + match w.tx.clone().try_reserve_owned() { + Ok(permit) => { if !self.registry.bind_writer(conn_id, w.id, meta).await { debug!( conn_id, writer_id = w.id, - "ME writer disappeared before bind commit, retrying" + "ME writer disappeared before bind commit, pruning stale writer" ); + drop(permit); + self.remove_writer_and_close_clients(w.id).await; continue; } + permit.send(WriterCommand::Data(payload.clone())); + self.stats.increment_me_writer_pick_success_try_total(pick_mode); if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); debug!( @@ -467,11 +406,7 @@ impl MePool { Err(TrySendError::Closed(_)) => { self.stats.increment_me_writer_pick_closed_total(pick_mode); warn!(writer_id = w.id, "ME writer channel closed"); - self.remove_writer_and_close_clients( - w.id, - MeWriterTeardownReason::RouteChannelClosed, - ) - .await; + self.remove_writer_and_close_clients(w.id).await; continue; } } @@ -490,46 +425,30 @@ impl 
MePool { self.stats.increment_me_writer_pick_blocking_fallback_total(); let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port()); let (payload, meta) = build_routed_payload(effective_our_addr); - match send_writer_command_with_timeout( - &w.tx, - WriterCommand::Data(payload.clone()), - self.me_route_blocking_send_timeout, - ) - .await - { - Ok(()) => { - self.stats - .increment_me_writer_pick_success_fallback_total(pick_mode); + match w.tx.clone().reserve_owned().await { + Ok(permit) => { if !self.registry.bind_writer(conn_id, w.id, meta).await { debug!( conn_id, writer_id = w.id, - "ME writer disappeared before fallback bind commit, retrying" + "ME writer disappeared before fallback bind commit, pruning stale writer" ); + drop(permit); + self.remove_writer_and_close_clients(w.id).await; continue; } + permit.send(WriterCommand::Data(payload.clone())); + self.stats + .increment_me_writer_pick_success_fallback_total(pick_mode); if w.generation < self.current_generation() { self.stats.increment_pool_stale_pick_total(); } return Ok(()); } - Err(TimedSendError::Closed(_)) => { + Err(_) => { self.stats.increment_me_writer_pick_closed_total(pick_mode); warn!(writer_id = w.id, "ME writer channel closed (blocking)"); - self.remove_writer_and_close_clients( - w.id, - MeWriterTeardownReason::RouteChannelClosed, - ) - .await; - } - Err(TimedSendError::Timeout(_)) => { - self.stats.increment_me_writer_pick_full_total(pick_mode); - debug!( - conn_id, - writer_id = w.id, - timeout_ms = self.me_route_blocking_send_timeout.as_millis() as u64, - "ME writer blocking fallback send timed out" - ); + self.remove_writer_and_close_clients(w.id).await; } } } @@ -660,23 +579,13 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - debug!( - conn_id, - writer_id 
= w.writer_id, - "ME close skipped: writer command channel is full" - ); - } - Err(TrySendError::Closed(_)) => { - debug!("ME close write failed"); - self.remove_writer_and_close_clients( - w.writer_id, - MeWriterTeardownReason::CloseRpcChannelClosed, - ) - .await; - } + if w.tx + .send(WriterCommand::DataAndFlush(Bytes::from(p))) + .await + .is_err() + { + debug!("ME close write failed"); + self.remove_writer_and_close_clients(w.writer_id).await; } } else { debug!(conn_id, "ME close skipped (writer missing)"); @@ -693,12 +602,8 @@ impl MePool { p.extend_from_slice(&conn_id.to_le_bytes()); match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { Ok(()) => {} - Err(TrySendError::Full(_)) => { - debug!( - conn_id, - writer_id = w.writer_id, - "ME close_conn skipped: writer command channel is full" - ); + Err(TrySendError::Full(cmd)) => { + let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; } Err(TrySendError::Closed(_)) => { debug!(conn_id, "ME close_conn skipped: writer channel closed"); diff --git a/src/transport/middle_proxy/health_adversarial_tests.rs b/src/transport/middle_proxy/tests/health_adversarial_tests.rs similarity index 69% rename from src/transport/middle_proxy/health_adversarial_tests.rs rename to src/transport/middle_proxy/tests/health_adversarial_tests.rs index 93b1d2b..5a8f612 100644 --- a/src/transport/middle_proxy/health_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/health_adversarial_tests.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; @@ -112,8 +113,6 @@ async fn make_pool( general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, - general.me_route_hybrid_max_wait_ms, - general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, 
general.me_route_inline_recovery_wait_ms, ); @@ -189,15 +188,48 @@ async fn sorted_writer_ids(pool: &Arc) -> Vec { ids } +fn lcg_next(state: &mut u64) -> u64 { + *state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + *state +} + +async fn draining_writer_ids(pool: &Arc) -> HashSet { + pool.writers + .read() + .await + .iter() + .filter(|writer| writer.draining.load(Ordering::Relaxed)) + .map(|writer| writer.id) + .collect::>() +} + +async fn set_writer_runtime_state( + pool: &Arc, + writer_id: u64, + draining: bool, + drain_started_at_epoch_secs: u64, + drain_deadline_epoch_secs: u64, +) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + writer + .draining_started_at_epoch_secs + .store(drain_started_at_epoch_secs, Ordering::Relaxed); + writer + .drain_deadline_epoch_secs + .store(drain_deadline_epoch_secs, Ordering::Relaxed); + } +} + #[tokio::test] async fn reap_draining_writers_clears_warn_state_when_pool_empty() { let (pool, _rng) = make_pool(128, 1, 1).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); warn_next_allowed.insert(11, Instant::now() + Duration::from_secs(5)); warn_next_allowed.insert(22, Instant::now() + Duration::from_secs(5)); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.is_empty()); } @@ -206,8 +238,6 @@ async fn reap_draining_writers_clears_warn_state_when_pool_empty() { async fn reap_draining_writers_respects_threshold_across_multiple_overflow_cycles() { let threshold = 3u64; let (pool, _rng) = make_pool(threshold, 1, 1).await; - pool.me_pool_drain_soft_evict_enabled - .store(false, Ordering::Relaxed); let now_epoch_secs = MePool::now_epoch_secs(); for writer_id in 1..=60u64 { @@ -222,9 +252,8 @@ async fn 
reap_draining_writers_respects_threshold_across_multiple_overflow_cycle } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..64 { - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; if writer_count(&pool).await <= threshold as usize { break; } @@ -252,12 +281,11 @@ async fn reap_draining_writers_handles_large_empty_writer_population() { } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..24 { if writer_count(&pool).await == 0 { break; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; } assert_eq!(writer_count(&pool).await, 0); @@ -281,12 +309,11 @@ async fn reap_draining_writers_processes_mass_deadline_expiry_without_unbounded_ } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..40 { if writer_count(&pool).await == 0 { break; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; } assert_eq!(writer_count(&pool).await, 0); @@ -297,7 +324,6 @@ async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_c let (pool, _rng) = make_pool(128, 1, 1).await; let now_epoch_secs = MePool::now_epoch_secs(); let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for wave in 0..40u64 { for offset in 0..8u64 { @@ -311,20 +337,15 @@ async fn reap_draining_writers_maintains_warn_state_subset_property_under_bulk_c .await; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.len() <= writer_count(&pool).await); let ids = 
sorted_writer_ids(&pool).await; for writer_id in ids.into_iter().take(3) { - let _ = pool - .remove_writer_and_close_clients( - writer_id, - crate::stats::MeWriterTeardownReason::ReapEmpty, - ) - .await; + let _ = pool.remove_writer_and_close_clients(writer_id).await; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.len() <= writer_count(&pool).await); } } @@ -346,10 +367,9 @@ async fn reap_draining_writers_budgeted_cleanup_never_increases_pool_size() { } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); let mut previous = writer_count(&pool).await; for _ in 0..32 { - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; let current = writer_count(&pool).await; assert!(current <= previous); previous = current; @@ -451,6 +471,149 @@ async fn me_health_monitor_eliminates_mixed_empty_and_deadline_backlog() { assert!(writer_count(&pool).await <= threshold as usize); } +#[tokio::test] +async fn reap_draining_writers_deterministic_mixed_state_churn_preserves_invariants() { + let threshold = 9u64; + let (pool, _rng) = make_pool(threshold, 1, 1).await; + let mut warn_next_allowed = HashMap::new(); + let mut seed = 0x9E37_79B9_7F4A_7C15u64; + let mut next_writer_id = 20_000u64; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=72u64 { + let bound_clients = if writer_id % 4 == 0 { 0 } else { 1 }; + let deadline = if writer_id % 5 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(500).saturating_add(writer_id), + bound_clients, + deadline, + ) + .await; + } + + for _round in 0..90 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = 
draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state keys must always be a subset of live draining writers" + ); + + let writer_ids = sorted_writer_ids(&pool).await; + if writer_ids.is_empty() { + continue; + } + + let remove_n = (lcg_next(&mut seed) % 3) as usize; + for writer_id in writer_ids.iter().copied().take(remove_n) { + let _ = pool.remove_writer_and_close_clients(writer_id).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if !survivors.is_empty() { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + set_writer_runtime_state(&pool, target, false, 0, 0).await; + } + + let survivors = sorted_writer_ids(&pool).await; + if survivors.len() > 1 { + let idx = (lcg_next(&mut seed) as usize) % survivors.len(); + let target = survivors[idx]; + let expired_deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + set_writer_runtime_state( + &pool, + target, + true, + now_epoch_secs.saturating_sub(120), + expired_deadline, + ) + .await; + } + + let inject_n = (lcg_next(&mut seed) % 4) as usize; + for _ in 0..inject_n { + let bound_clients = if lcg_next(&mut seed) & 1 == 0 { 0 } else { 1 }; + let deadline = if lcg_next(&mut seed) & 1 == 0 { + now_epoch_secs.saturating_sub(1) + } else { + 0 + }; + insert_draining_writer( + &pool, + next_writer_id, + now_epoch_secs.saturating_sub(240), + bound_clients, + deadline, + ) + .await; + next_writer_id = next_writer_id.saturating_add(1); + } + } + + for _ in 0..64 { + reap_draining_writers(&pool, &mut warn_next_allowed).await; + if writer_count(&pool).await <= threshold as usize { + break; + } + } + + assert!(writer_count(&pool).await <= threshold as usize); + let draining_ids = draining_writer_ids(&pool).await; + assert!(warn_next_allowed.keys().all(|id| draining_ids.contains(id))); +} + +#[tokio::test] +async fn 
reap_draining_writers_repeated_draining_flips_never_leave_stale_warn_state() { + let (pool, _rng) = make_pool(64, 1, 1).await; + let now_epoch_secs = MePool::now_epoch_secs(); + + for writer_id in 1..=24u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(240), + 1, + 0, + ) + .await; + } + + let mut warn_next_allowed = HashMap::new(); + for _round in 0..48u64 { + for writer_id in 1..=24u64 { + let draining = (writer_id + _round) % 3 != 0; + set_writer_runtime_state( + &pool, + writer_id, + draining, + now_epoch_secs.saturating_sub(120), + 0, + ) + .await; + } + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + let draining_ids = draining_writer_ids(&pool).await; + assert!( + warn_next_allowed.keys().all(|id| draining_ids.contains(id)), + "warn-state map must not retain entries for writers outside draining set" + ); + } +} + #[test] fn health_drain_close_budget_is_within_expected_bounds() { let budget = health_drain_close_budget(); diff --git a/src/transport/middle_proxy/health_integration_tests.rs b/src/transport/middle_proxy/tests/health_integration_tests.rs similarity index 93% rename from src/transport/middle_proxy/health_integration_tests.rs rename to src/transport/middle_proxy/tests/health_integration_tests.rs index fbbffce..c4f4dd5 100644 --- a/src/transport/middle_proxy/health_integration_tests.rs +++ b/src/transport/middle_proxy/tests/health_integration_tests.rs @@ -111,8 +111,6 @@ async fn make_pool( general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, - general.me_route_hybrid_max_wait_ms, - general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ); @@ -169,6 +167,20 @@ async fn insert_draining_writer( } } +async fn wait_for_pool_empty(pool: &Arc, timeout: Duration) { + let start = Instant::now(); + loop { + if pool.writers.read().await.is_empty() { + return; + } + assert!( + 
start.elapsed() < timeout, + "timed out waiting for pool.writers to become empty" + ); + tokio::time::sleep(Duration::from_millis(5)).await; + } +} + #[tokio::test] async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { let (pool, rng) = make_pool(128, 1, 1).await; @@ -186,7 +198,7 @@ async fn me_health_monitor_drains_expired_backlog_over_multiple_cycles() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(60)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; @@ -202,7 +214,7 @@ async fn me_health_monitor_cleans_empty_draining_writers_without_force_close() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(30)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; @@ -227,7 +239,7 @@ async fn me_health_monitor_converges_retry_like_threshold_backlog_to_empty() { } let monitor = tokio::spawn(me_health_monitor(pool.clone(), rng, 0)); - tokio::time::sleep(Duration::from_millis(60)).await; + wait_for_pool_empty(&pool, Duration::from_secs(1)).await; monitor.abort(); let _ = monitor.await; diff --git a/src/transport/middle_proxy/health_regression_tests.rs b/src/transport/middle_proxy/tests/health_regression_tests.rs similarity index 66% rename from src/transport/middle_proxy/health_regression_tests.rs rename to src/transport/middle_proxy/tests/health_regression_tests.rs index 3c7b919..d391484 100644 --- a/src/transport/middle_proxy/health_regression_tests.rs +++ b/src/transport/middle_proxy/tests/health_regression_tests.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; use std::time::{Duration, Instant}; -use bytes::Bytes; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -42,7 +41,7 @@ async fn make_pool(me_pool_drain_threshold: 
u64) -> Arc { NetworkDecision::default(), None, Arc::new(SecureRandom::new()), - Arc::new(Stats::new()), + Arc::new(Stats::default()), general.me_keepalive_enabled, general.me_keepalive_interval_secs, general.me_keepalive_jitter_secs, @@ -107,8 +106,6 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc { general.me_warn_rate_limit_ms, MeRouteNoWriterMode::default(), general.me_route_no_writer_wait_ms, - general.me_route_hybrid_max_wait_ms, - general.me_route_blocking_send_timeout_ms, general.me_route_inline_recovery_attempts, general.me_route_inline_recovery_wait_ms, ) @@ -179,6 +176,21 @@ async fn current_writer_ids(pool: &Arc) -> Vec { writer_ids } +async fn writer_exists(pool: &Arc, writer_id: u64) -> bool { + pool.writers + .read() + .await + .iter() + .any(|writer| writer.id == writer_id) +} + +async fn set_writer_draining(pool: &Arc, writer_id: u64, draining: bool) { + let writers = pool.writers.read().await; + if let Some(writer) = writers.iter().find(|writer| writer.id == writer_id) { + writer.draining.store(draining, Ordering::Relaxed); + } +} + #[tokio::test] async fn reap_draining_writers_drops_warn_state_for_removed_writer() { let pool = make_pool(128).await; @@ -192,17 +204,14 @@ async fn reap_draining_writers_drops_warn_state_for_removed_writer() { ) .await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.contains_key(&7)); - let _ = pool - .remove_writer_and_close_clients(7, crate::stats::MeWriterTeardownReason::ReapEmpty) - .await; + let _ = pool.remove_writer_and_close_clients(7).await; assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; 
assert!(!warn_next_allowed.contains_key(&7)); } @@ -214,96 +223,12 @@ async fn reap_draining_writers_removes_empty_draining_writers() { insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(30), 0, 0).await; insert_draining_writer(&pool, 3, now_epoch_secs.saturating_sub(20), 1, 0).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert_eq!(current_writer_ids(&pool).await, vec![3]); } -#[tokio::test] -async fn reap_draining_writers_does_not_block_on_stuck_writer_close_signal() { - let pool = make_pool(128).await; - let now_epoch_secs = MePool::now_epoch_secs(); - - let (blocked_tx, blocked_rx) = mpsc::channel::(1); - assert!( - blocked_tx - .try_send(WriterCommand::Data(Bytes::from_static(b"stuck"))) - .is_ok() - ); - let blocked_rx_guard = tokio::spawn(async move { - let _hold_rx = blocked_rx; - tokio::time::sleep(Duration::from_secs(30)).await; - }); - - let blocked_writer_id = 90u64; - let blocked_writer = MeWriter { - id: blocked_writer_id, - addr: SocketAddr::new( - IpAddr::V4(Ipv4Addr::LOCALHOST), - 4500 + blocked_writer_id as u16, - ), - source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST), - writer_dc: 2, - generation: 1, - contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())), - created_at: Instant::now() - Duration::from_secs(blocked_writer_id), - tx: blocked_tx.clone(), - cancel: CancellationToken::new(), - degraded: Arc::new(AtomicBool::new(false)), - rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), - draining: Arc::new(AtomicBool::new(true)), - draining_started_at_epoch_secs: Arc::new(AtomicU64::new( - now_epoch_secs.saturating_sub(120), - )), - drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), - allow_drain_fallback: Arc::new(AtomicBool::new(false)), - }; - pool.writers.write().await.push(blocked_writer); - pool.registry - 
.register_writer(blocked_writer_id, blocked_tx) - .await; - pool.conn_count.fetch_add(1, Ordering::Relaxed); - - insert_draining_writer(&pool, 91, now_epoch_secs.saturating_sub(110), 0, 0).await; - - let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - - let reap_res = tokio::time::timeout( - Duration::from_millis(500), - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed), - ) - .await; - blocked_rx_guard.abort(); - - assert!(reap_res.is_ok(), "reap should not block on close signal"); - assert!(current_writer_ids(&pool).await.is_empty()); - assert_eq!(pool.stats.get_me_writer_close_signal_drop_total(), 2); - assert_eq!(pool.stats.get_me_writer_close_signal_channel_full_total(), 1); - assert_eq!(pool.stats.get_me_draining_writers_reap_progress_total(), 2); - let activity = pool.registry.writer_activity_snapshot().await; - assert!(!activity.bound_clients_by_writer.contains_key(&blocked_writer_id)); - assert!(!activity.bound_clients_by_writer.contains_key(&91)); - let (probe_conn_id, _rx) = pool.registry.register().await; - assert!( - !pool.registry - .bind_writer( - probe_conn_id, - blocked_writer_id, - ConnMeta { - target_dc: 2, - client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6400), - our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), - proto_flags: 0, - }, - ) - .await - ); - let _ = pool.registry.unregister(probe_conn_id).await; -} - #[tokio::test] async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() { let pool = make_pool(2).await; @@ -313,9 +238,8 @@ async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() { insert_draining_writer(&pool, 33, now_epoch_secs.saturating_sub(20), 1, 0).await; insert_draining_writer(&pool, 44, now_epoch_secs.saturating_sub(10), 1, 0).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, 
&mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert_eq!(current_writer_ids(&pool).await, vec![33, 44]); } @@ -333,9 +257,8 @@ async fn reap_draining_writers_deadline_force_close_applies_under_threshold() { ) .await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(current_writer_ids(&pool).await.is_empty()); } @@ -357,13 +280,129 @@ async fn reap_draining_writers_limits_closes_per_health_tick() { .await; } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert_eq!(pool.writers.read().await.len(), writer_total - close_budget); } +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_deadline_backlog_writers() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(5); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(60), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + let target_writer_id = writer_total as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_keeps_warn_state_for_overflow_backlog_writers() { + let pool = make_pool(1).await; + let 
now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_add(6); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(300).saturating_add(writer_id), + 1, + 0, + ) + .await; + } + let target_writer_id = writer_total.saturating_sub(1) as u64; + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert( + target_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, target_writer_id).await); + assert!(warn_next_allowed.contains_key(&target_writer_id)); +} + +#[tokio::test] +async fn reap_draining_writers_drops_warn_state_when_writer_exits_draining_state() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + insert_draining_writer(&pool, 71, now_epoch_secs.saturating_sub(60), 1, 0).await; + + let mut warn_next_allowed = HashMap::new(); + warn_next_allowed.insert(71, Instant::now() + Duration::from_secs(300)); + + set_writer_draining(&pool, 71, false).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; + + assert!(writer_exists(&pool, 71).await); + assert!( + !warn_next_allowed.contains_key(&71), + "warn cooldown state must be dropped after writer leaves draining state" + ); +} + +#[tokio::test] +async fn reap_draining_writers_preserves_warn_state_across_multiple_budget_deferrals() { + let pool = make_pool(0).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let close_budget = health_drain_close_budget(); + let writer_total = close_budget.saturating_mul(2).saturating_add(1); + for writer_id in 1..=writer_total as u64 { + insert_draining_writer( + &pool, + writer_id, + now_epoch_secs.saturating_sub(120), + 1, + now_epoch_secs.saturating_sub(1), + ) + .await; + } + + let tail_writer_id = writer_total as u64; + let mut warn_next_allowed = 
HashMap::new(); + warn_next_allowed.insert( + tail_writer_id, + Instant::now() + Duration::from_secs(300), + ); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(writer_exists(&pool, tail_writer_id).await); + assert!(warn_next_allowed.contains_key(&tail_writer_id)); + + reap_draining_writers(&pool, &mut warn_next_allowed).await; + assert!(!writer_exists(&pool, tail_writer_id).await); + assert!( + !warn_next_allowed.contains_key(&tail_writer_id), + "warn cooldown state must clear once writer is actually removed" + ); +} + #[tokio::test] async fn reap_draining_writers_backlog_drains_across_ticks() { let pool = make_pool(128).await; @@ -381,13 +420,12 @@ async fn reap_draining_writers_backlog_drains_across_ticks() { .await; } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..8 { if pool.writers.read().await.is_empty() { break; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; } assert!(pool.writers.read().await.is_empty()); @@ -411,10 +449,9 @@ async fn reap_draining_writers_threshold_backlog_converges_to_threshold() { .await; } let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for _ in 0..16 { - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; if pool.writers.read().await.len() <= threshold as usize { break; } @@ -431,9 +468,8 @@ async fn reap_draining_writers_threshold_zero_preserves_non_expired_non_empty_wr insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(30), 1, 0).await; insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(20), 1, 
0).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert_eq!(current_writer_ids(&pool).await, vec![10, 20, 30]); } @@ -456,9 +492,8 @@ async fn reap_draining_writers_prioritizes_force_close_before_empty_cleanup() { let empty_writer_id = close_budget.saturating_add(2) as u64; insert_draining_writer(&pool, empty_writer_id, now_epoch_secs.saturating_sub(20), 0, 0).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert_eq!(current_writer_ids(&pool).await, vec![1, empty_writer_id]); } @@ -470,9 +505,8 @@ async fn reap_draining_writers_empty_cleanup_does_not_increment_force_close_metr insert_draining_writer(&pool, 1, now_epoch_secs.saturating_sub(60), 0, 0).await; insert_draining_writer(&pool, 2, now_epoch_secs.saturating_sub(50), 0, 0).await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(current_writer_ids(&pool).await.is_empty()); assert_eq!(pool.stats.get_pool_force_close_total(), 0); @@ -499,9 +533,8 @@ async fn reap_draining_writers_handles_duplicate_force_close_requests_for_same_w ) .await; let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(current_writer_ids(&pool).await.is_empty()); } @@ -511,7 +544,6 @@ async fn 
reap_draining_writers_warn_state_never_exceeds_live_draining_population let pool = make_pool(128).await; let now_epoch_secs = MePool::now_epoch_secs(); let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for wave in 0..12u64 { for offset in 0..9u64 { @@ -524,19 +556,14 @@ async fn reap_draining_writers_warn_state_never_exceeds_live_draining_population ) .await; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); let existing_writer_ids = current_writer_ids(&pool).await; for writer_id in existing_writer_ids.into_iter().take(4) { - let _ = pool - .remove_writer_and_close_clients( - writer_id, - crate::stats::MeWriterTeardownReason::ReapEmpty, - ) - .await; + let _ = pool.remove_writer_and_close_clients(writer_id).await; } - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; assert!(warn_next_allowed.len() <= pool.writers.read().await.len()); } } @@ -546,7 +573,6 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat let pool = make_pool(6).await; let now_epoch_secs = MePool::now_epoch_secs(); let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); for writer_id in 1..=18u64 { let bound_clients = if writer_id % 3 == 0 { 0 } else { 1 }; @@ -566,7 +592,7 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat } for _ in 0..16 { - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; + reap_draining_writers(&pool, &mut warn_next_allowed).await; if pool.writers.read().await.len() <= 6 { break; } @@ -576,83 +602,6 @@ async fn reap_draining_writers_mixed_backlog_converges_without_leaking_warn_stat assert!(warn_next_allowed.len() <= 
pool.writers.read().await.len()); } -#[tokio::test] -async fn reap_draining_writers_soft_evicts_stuck_writer_with_per_writer_cap() { - let pool = make_pool(128).await; - pool.me_pool_drain_soft_evict_enabled.store(true, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_grace_secs.store(0, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_per_writer.store(1, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_budget_per_core.store(8, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_cooldown_ms - .store(1, Ordering::Relaxed); - - let now_epoch_secs = MePool::now_epoch_secs(); - insert_draining_writer( - &pool, - 77, - now_epoch_secs.saturating_sub(240), - 3, - now_epoch_secs.saturating_add(3_600), - ) - .await; - let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - - let activity = pool.registry.writer_activity_snapshot().await; - assert_eq!(activity.bound_clients_by_writer.get(&77), Some(&2)); - assert_eq!(pool.stats.get_pool_drain_soft_evict_total(), 1); - assert_eq!(pool.stats.get_pool_drain_soft_evict_writer_total(), 1); - assert_eq!(current_writer_ids(&pool).await, vec![77]); -} - -#[tokio::test] -async fn reap_draining_writers_soft_evict_respects_cooldown_per_writer() { - let pool = make_pool(128).await; - pool.me_pool_drain_soft_evict_enabled.store(true, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_grace_secs.store(0, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_per_writer.store(1, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_budget_per_core.store(8, Ordering::Relaxed); - pool.me_pool_drain_soft_evict_cooldown_ms - .store(60_000, Ordering::Relaxed); - - let now_epoch_secs = MePool::now_epoch_secs(); - insert_draining_writer( - &pool, - 88, - now_epoch_secs.saturating_sub(240), - 3, - now_epoch_secs.saturating_add(3_600), - ) - .await; - let mut warn_next_allowed = HashMap::new(); - let mut 
soft_evict_next_allowed = HashMap::new(); - - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - - let activity = pool.registry.writer_activity_snapshot().await; - assert_eq!(activity.bound_clients_by_writer.get(&88), Some(&2)); - assert_eq!(pool.stats.get_pool_drain_soft_evict_total(), 1); - assert_eq!(pool.stats.get_pool_drain_soft_evict_writer_total(), 1); -} - -#[tokio::test] -async fn reap_draining_writers_instadrain_removes_non_expired_writers_immediately() { - let pool = make_pool(0).await; - pool.me_instadrain.store(true, Ordering::Relaxed); - let now_epoch_secs = MePool::now_epoch_secs(); - insert_draining_writer(&pool, 101, now_epoch_secs.saturating_sub(5), 1, 0).await; - insert_draining_writer(&pool, 102, now_epoch_secs.saturating_sub(4), 1, 0).await; - let mut warn_next_allowed = HashMap::new(); - let mut soft_evict_next_allowed = HashMap::new(); - - reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed).await; - - assert!(current_writer_ids(&pool).await.is_empty()); -} - #[test] fn general_config_default_drain_threshold_remains_enabled() { assert_eq!(GeneralConfig::default().me_pool_drain_threshold, 32); @@ -675,3 +624,15 @@ fn general_config_default_drain_threshold_remains_enabled() { ); assert_eq!(GeneralConfig::default().me_bind_stale_mode, MeBindStaleMode::Never); } + +#[tokio::test] +async fn prune_closed_writers_closes_bound_clients_when_writer_is_non_empty() { + let pool = make_pool(128).await; + let now_epoch_secs = MePool::now_epoch_secs(); + let conn_ids = insert_draining_writer(&pool, 910, now_epoch_secs.saturating_sub(60), 1, 0).await; + + pool.prune_closed_writers().await; + + assert!(!writer_exists(&pool, 910).await); + assert!(pool.registry.get_writer(conn_ids[0]).await.is_none()); +} diff --git a/src/transport/middle_proxy/tests/pool_refill_security_tests.rs 
b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs new file mode 100644 index 0000000..61c375f --- /dev/null +++ b/src/transport/middle_proxy/tests/pool_refill_security_tests.rs @@ -0,0 +1,155 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +use super::pool::MePool; + +async fn make_pool() -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + 
general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_instadrain, + general.me_pool_drain_threshold, + general.me_pool_drain_soft_evict_enabled, + general.me_pool_drain_soft_evict_grace_secs, + general.me_pool_drain_soft_evict_per_writer, + general.me_pool_drain_soft_evict_budget_per_core, + general.me_pool_drain_soft_evict_cooldown_ms, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + +#[tokio::test] +async fn connectable_endpoints_waits_until_quarantine_expires() { + let pool = make_pool().await; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 31, 0, 11)), 443); + + { + let mut guard = 
pool.endpoint_quarantine.lock().await; + guard.insert(addr, Instant::now() + Duration::from_millis(80)); + } + + let started = Instant::now(); + let endpoints = pool.connectable_endpoints_for_test(&[addr]).await; + let elapsed = started.elapsed(); + + assert_eq!(endpoints, vec![addr]); + assert!( + elapsed >= Duration::from_millis(50), + "single-endpoint DC should honor quarantine before retry" + ); +} + +#[tokio::test] +async fn connectable_endpoints_releases_quarantine_lock_before_sleep() { + let pool = make_pool().await; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 31, 0, 12)), 443); + + { + let mut guard = pool.endpoint_quarantine.lock().await; + guard.insert(addr, Instant::now() + Duration::from_millis(120)); + } + + let pool_for_task = Arc::clone(&pool); + let task = tokio::spawn(async move { pool_for_task.connectable_endpoints_for_test(&[addr]).await }); + + tokio::time::sleep(Duration::from_millis(10)).await; + + let quarantine_check = tokio::time::timeout( + Duration::from_millis(40), + pool.is_endpoint_quarantined(addr), + ) + .await; + assert!( + quarantine_check.is_ok(), + "quarantine lock must not be held while waiting for expiry" + ); + assert!(quarantine_check.expect("timeout")); + + let endpoints = tokio::time::timeout(Duration::from_millis(300), task) + .await + .expect("connectable_endpoints task timed out") + .expect("task join failed"); + assert_eq!(endpoints, vec![addr]); +} diff --git a/src/transport/middle_proxy/tests/pool_writer_security_tests.rs b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs new file mode 100644 index 0000000..27b9635 --- /dev/null +++ b/src/transport/middle_proxy/tests/pool_writer_security_tests.rs @@ -0,0 +1,382 @@ +use std::collections::{HashMap, HashSet}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use 
tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use super::pool::{MePool, MeWriter, WriterContour}; +use super::registry::ConnMeta; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool() -> Arc { + let general = GeneralConfig::default(); + + MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + Arc::new(SecureRandom::new()), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + 
general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_instadrain, + general.me_pool_drain_threshold, + general.me_pool_drain_soft_evict_enabled, + general.me_pool_drain_soft_evict_grace_secs, + general.me_pool_drain_soft_evict_per_writer, + general.me_pool_drain_soft_evict_budget_per_core, + general.me_pool_drain_soft_evict_cooldown_ms, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + MeWriterPickMode::default(), + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + MeRouteNoWriterMode::default(), + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ) +} + +async fn insert_writer( + pool: &Arc, + writer_id: u64, + writer_dc: i32, + addr: SocketAddr, + draining: bool, + created_at: Instant, +) { + let (tx, _rx) = mpsc::channel::(8); + let contour = if draining { + WriterContour::Draining + } else { + WriterContour::Active + }; + let writer = MeWriter { + id: writer_id, + addr, + source_ip: addr.ip(), + writer_dc, + generation: pool.current_generation(), + contour: 
Arc::new(AtomicU8::new(contour.as_u8())), + created_at, + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(draining)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + + pool.writers.write().await.push(writer); + pool.registry.register_writer(writer_id, tx).await; + pool.conn_count.fetch_add(1, Ordering::Relaxed); +} + +async fn current_writer_ids(pool: &Arc) -> HashSet { + pool.writers + .read() + .await + .iter() + .map(|writer| writer.id) + .collect() +} + +async fn bind_conn_to_writer(pool: &Arc, writer_id: u64, port: u16) -> u64 { + let (conn_id, _rx) = pool.registry.register().await; + let bound = pool + .registry + .bind_writer( + conn_id, + writer_id, + ConnMeta { + target_dc: 2, + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port), + our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + proto_flags: 0, + }, + ) + .await; + assert!(bound, "writer binding must succeed"); + conn_id +} + +#[tokio::test] +async fn remove_draining_writer_still_quarantines_flapping_endpoint() { + let pool = make_pool().await; + let writer_id = 77; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 77)), 443); + insert_writer( + &pool, + writer_id, + 2, + addr, + true, + Instant::now() - Duration::from_secs(1), + ) + .await; + + pool.remove_writer_and_close_clients(writer_id).await; + + let writer_still_present = pool + .writers + .read() + .await + .iter() + .any(|writer| writer.id == writer_id); + assert!( + !writer_still_present, + "writer must be removed from pool after cleanup" + ); + assert!( + pool.is_endpoint_quarantined(addr).await, + "draining removals must still quarantine flapping endpoints" + ); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} + 
+#[tokio::test] +async fn positive_remove_writer_cleans_bound_registry_routes() { + let pool = make_pool().await; + let writer_id = 88; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 88)), 443); + insert_writer(&pool, writer_id, 2, addr, false, Instant::now()).await; + + let conn_id = bind_conn_to_writer(&pool, writer_id, 7301).await; + assert!(pool.registry.get_writer(conn_id).await.is_some()); + + pool.remove_writer_and_close_clients(writer_id).await; + + assert!(pool.registry.get_writer(conn_id).await.is_none()); + assert!(!current_writer_ids(&pool).await.contains(&writer_id)); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} + +#[tokio::test] +async fn negative_unknown_writer_removal_is_noop() { + let pool = make_pool().await; + let before_quarantine = pool.stats.get_me_endpoint_quarantine_total(); + + pool.remove_writer_and_close_clients(9_999_001).await; + + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); + assert_eq!(pool.stats.get_me_endpoint_quarantine_total(), before_quarantine); +} + +#[tokio::test] +async fn edge_draining_only_detach_rejects_active_writer() { + let pool = make_pool().await; + let writer_id = 91; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 91)), 443); + insert_writer(&pool, writer_id, 2, addr, false, Instant::now()).await; + + let removed = pool.remove_draining_writer_hard_detach(writer_id).await; + assert!(!removed, "active writer must not be detached by draining-only path"); + assert!(current_writer_ids(&pool).await.contains(&writer_id)); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 1); + + pool.remove_writer_and_close_clients(writer_id).await; +} + +#[tokio::test] +async fn adversarial_blackhat_single_remove_establishes_single_quarantine_entry() { + let pool = make_pool().await; + let writer_id = 93; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 93)), 443); + insert_writer( + &pool, + writer_id, + 
2, + addr, + true, + Instant::now() - Duration::from_secs(1), + ) + .await; + + pool.remove_writer_and_close_clients(writer_id).await; + assert!(pool.is_endpoint_quarantined(addr).await); + assert_eq!(pool.endpoint_quarantine.lock().await.len(), 1); +} + +#[tokio::test] +async fn integration_old_uptime_writer_does_not_trigger_flap_quarantine() { + let pool = make_pool().await; + let writer_id = 94; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 12, 0, 94)), 443); + insert_writer( + &pool, + writer_id, + 2, + addr, + false, + Instant::now() - Duration::from_secs(30), + ) + .await; + + let before = pool.stats.get_me_endpoint_quarantine_total(); + pool.remove_writer_and_close_clients(writer_id).await; + let after = pool.stats.get_me_endpoint_quarantine_total(); + + assert_eq!(after, before); + assert!(!pool.is_endpoint_quarantined(addr).await); +} + +#[tokio::test] +async fn light_fuzz_insert_remove_schedule_preserves_pool_invariants() { + let pool = make_pool().await; + let mut seed = 0xA11C_E551_D00D_BAADu64; + let mut model = HashSet::::new(); + + for _ in 0..240 { + seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1); + let writer_id = 1 + (seed % 64); + let op_insert = ((seed >> 17) & 1) == 0; + + if op_insert { + if !model.contains(&writer_id) { + let ip_octet = (writer_id % 250) as u8; + let addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 13, 0, ip_octet.max(1))), + 4000 + writer_id as u16, + ); + let draining = ((seed >> 33) & 1) == 1; + let created_at = if draining { + Instant::now() - Duration::from_secs(1) + } else { + Instant::now() - Duration::from_secs(30) + }; + insert_writer(&pool, writer_id, 2, addr, draining, created_at).await; + model.insert(writer_id); + } + } else { + pool.remove_writer_and_close_clients(writer_id).await; + model.remove(&writer_id); + } + + let actual_ids = current_writer_ids(&pool).await; + assert_eq!(actual_ids, model, "writer-id set must match model under fuzz schedule"); + 
assert_eq!(pool.conn_count.load(Ordering::Relaxed) as usize, model.len()); + } + + for writer_id in model { + pool.remove_writer_and_close_clients(writer_id).await; + } + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_duplicate_removals_are_idempotent() { + let pool = make_pool().await; + + for writer_id in 1..=48u64 { + let addr = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 14, (writer_id / 250) as u8, (writer_id % 250) as u8)), + 5000 + writer_id as u16, + ); + insert_writer( + &pool, + writer_id, + 2, + addr, + true, + Instant::now() - Duration::from_secs(1), + ) + .await; + } + + let mut tasks = Vec::new(); + for worker in 0..8u64 { + let pool = Arc::clone(&pool); + tasks.push(tokio::spawn(async move { + for writer_id in 1..=48u64 { + if ((writer_id + worker) & 1) == 0 { + pool.remove_writer_and_close_clients(writer_id).await; + } else { + pool.remove_writer_and_close_clients(100_000 + writer_id).await; + } + } + })); + } + + for task in tasks { + task.await.expect("stress remover task must not panic"); + } + + for writer_id in 1..=48u64 { + pool.remove_writer_and_close_clients(writer_id).await; + } + + assert!(pool.writers.read().await.is_empty()); + assert_eq!(pool.conn_count.load(Ordering::Relaxed), 0); +} diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs new file mode 100644 index 0000000..5850420 --- /dev/null +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -0,0 +1,269 @@ +use std::collections::HashMap; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; + +use super::codec::WriterCommand; +use 
super::pool::{MePool, MeWriter, WriterContour}; +use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; +use crate::crypto::SecureRandom; +use crate::network::probe::NetworkDecision; +use crate::stats::Stats; + +async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) { + let general = GeneralConfig { + me_route_no_writer_mode: MeRouteNoWriterMode::AsyncRecoveryFailfast, + me_route_no_writer_wait_ms: 50, + me_writer_pick_mode: MeWriterPickMode::SortedRr, + me_deterministic_writer_sort: true, + ..GeneralConfig::default() + }; + + let rng = Arc::new(SecureRandom::new()); + let pool = MePool::new( + None, + vec![1u8; 32], + None, + false, + None, + Vec::new(), + 1, + None, + 12, + 1200, + HashMap::new(), + HashMap::new(), + None, + NetworkDecision::default(), + None, + rng.clone(), + Arc::new(Stats::default()), + general.me_keepalive_enabled, + general.me_keepalive_interval_secs, + general.me_keepalive_jitter_secs, + general.me_keepalive_payload_random, + general.rpc_proxy_req_every, + general.me_warmup_stagger_enabled, + general.me_warmup_step_delay_ms, + general.me_warmup_step_jitter_ms, + general.me_reconnect_max_concurrent_per_dc, + general.me_reconnect_backoff_base_ms, + general.me_reconnect_backoff_cap_ms, + general.me_reconnect_fast_retry_count, + general.me_single_endpoint_shadow_writers, + general.me_single_endpoint_outage_mode_enabled, + general.me_single_endpoint_outage_disable_quarantine, + general.me_single_endpoint_outage_backoff_min_ms, + general.me_single_endpoint_outage_backoff_max_ms, + general.me_single_endpoint_shadow_rotate_every_secs, + general.me_floor_mode, + general.me_adaptive_floor_idle_secs, + general.me_adaptive_floor_min_writers_single_endpoint, + general.me_adaptive_floor_min_writers_multi_endpoint, + general.me_adaptive_floor_recover_grace_secs, + general.me_adaptive_floor_writers_per_core_total, + general.me_adaptive_floor_cpu_cores_override, + general.me_adaptive_floor_max_extra_writers_single_per_core, + 
general.me_adaptive_floor_max_extra_writers_multi_per_core, + general.me_adaptive_floor_max_active_writers_per_core, + general.me_adaptive_floor_max_warm_writers_per_core, + general.me_adaptive_floor_max_active_writers_global, + general.me_adaptive_floor_max_warm_writers_global, + general.hardswap, + general.me_pool_drain_ttl_secs, + general.me_instadrain, + general.me_pool_drain_threshold, + general.me_pool_drain_soft_evict_enabled, + general.me_pool_drain_soft_evict_grace_secs, + general.me_pool_drain_soft_evict_per_writer, + general.me_pool_drain_soft_evict_budget_per_core, + general.me_pool_drain_soft_evict_cooldown_ms, + general.effective_me_pool_force_close_secs(), + general.me_pool_min_fresh_ratio, + general.me_hardswap_warmup_delay_min_ms, + general.me_hardswap_warmup_delay_max_ms, + general.me_hardswap_warmup_extra_passes, + general.me_hardswap_warmup_pass_backoff_base_ms, + general.me_bind_stale_mode, + general.me_bind_stale_ttl_secs, + general.me_secret_atomic_snapshot, + general.me_deterministic_writer_sort, + general.me_writer_pick_mode, + general.me_writer_pick_sample_size, + MeSocksKdfPolicy::default(), + general.me_writer_cmd_channel_capacity, + general.me_route_channel_capacity, + general.me_route_backpressure_base_timeout_ms, + general.me_route_backpressure_high_timeout_ms, + general.me_route_backpressure_high_watermark_pct, + general.me_reader_route_data_wait_ms, + general.me_health_interval_ms_unhealthy, + general.me_health_interval_ms_healthy, + general.me_warn_rate_limit_ms, + general.me_route_no_writer_mode, + general.me_route_no_writer_wait_ms, + general.me_route_inline_recovery_attempts, + general.me_route_inline_recovery_wait_ms, + ); + + (pool, rng) +} + +async fn insert_writer( + pool: &Arc<MePool>, + writer_id: u64, + writer_dc: i32, + addr: SocketAddr, + register_in_registry: bool, +) -> mpsc::Receiver<WriterCommand> { + let (tx, rx) = mpsc::channel::<WriterCommand>(8); + let writer = MeWriter { + id: writer_id, + addr, + source_ip: addr.ip(), + writer_dc, + generation: 
pool.current_generation(), + contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())), + created_at: Instant::now(), + tx: tx.clone(), + cancel: CancellationToken::new(), + degraded: Arc::new(AtomicBool::new(false)), + rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)), + draining: Arc::new(AtomicBool::new(false)), + draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)), + drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)), + allow_drain_fallback: Arc::new(AtomicBool::new(false)), + }; + + pool.writers.write().await.push(writer); + { + let mut map = pool.proxy_map_v4.write().await; + map.entry(writer_dc) + .or_insert_with(Vec::new) + .push((addr.ip(), addr.port())); + } + pool.rebuild_endpoint_dc_map().await; + if register_in_registry { + pool.registry.register_writer(writer_id, tx).await; + } + rx +} + +async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duration) -> usize { + let start = Instant::now(); + let mut data_count = 0usize; + while Instant::now().duration_since(start) < budget { + let remaining = budget.saturating_sub(Instant::now().duration_since(start)); + match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await { + Ok(Some(WriterCommand::Data(_))) => data_count += 1, + Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1, + Ok(Some(WriterCommand::Close)) => {} + Ok(None) => break, + Err(_) => break, + } + } + data_count +} + +#[tokio::test] +async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() { + let (pool, _rng) = make_pool().await; + pool.rr.store(0, Ordering::Relaxed); + + let (conn_id, _rx) = pool.registry.register().await; + let mut stale_rx = insert_writer( + &pool, + 10, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 10)), 443), + false, + ) + .await; + let mut live_rx = insert_writer( + &pool, + 11, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 11)), 443), + true, + ) + .await; + + let result = pool + .send_proxy_req( + conn_id, + 2, + 
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30000), + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + b"hello", + 0, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(recv_data_count(&mut stale_rx, Duration::from_millis(50)).await, 0); + assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1); + + let bound = pool.registry.get_writer(conn_id).await; + assert!(bound.is_some()); + assert_eq!(bound.expect("writer should be bound").writer_id, 11); +} + +#[tokio::test] +async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay() { + let (pool, _rng) = make_pool().await; + pool.rr.store(0, Ordering::Relaxed); + + let (conn_id, _rx) = pool.registry.register().await; + + let mut stale_rx_1 = insert_writer( + &pool, + 21, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 21)), 443), + false, + ) + .await; + let mut stale_rx_2 = insert_writer( + &pool, + 22, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 22)), 443), + false, + ) + .await; + let mut live_rx = insert_writer( + &pool, + 23, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 1, 23)), 443), + true, + ) + .await; + + let result = pool + .send_proxy_req( + conn_id, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 30001), + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443), + b"storm", + 0, + None, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(recv_data_count(&mut stale_rx_1, Duration::from_millis(50)).await, 0); + assert_eq!(recv_data_count(&mut stale_rx_2, Duration::from_millis(50)).await, 0); + assert_eq!(recv_data_count(&mut live_rx, Duration::from_millis(50)).await, 1); + + let writers = pool.writers.read().await; + let writer_ids = writers.iter().map(|w| w.id).collect::<Vec<u64>>(); + drop(writers); + assert_eq!(writer_ids, vec![23]); +} diff --git a/src/transport/socket.rs b/src/transport/socket.rs index 3ff96a2..3a35133 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -23,7 
+23,7 @@ pub fn configure_tcp_socket( let socket = socket2::SockRef::from(stream); // Disable Nagle's algorithm for lower latency - socket.set_nodelay(true)?; + socket.set_tcp_nodelay(true)?; // Set keepalive if enabled if keepalive { @@ -54,7 +54,7 @@ pub fn configure_client_socket( let socket = socket2::SockRef::from(stream); // Disable Nagle's algorithm - socket.set_nodelay(true)?; + socket.set_tcp_nodelay(true)?; // Set keepalive let keepalive = TcpKeepalive::new() @@ -129,7 +129,7 @@ pub fn create_outgoing_socket_bound(addr: SocketAddr, bind_addr: Option) socket.set_nonblocking(true)?; // Disable Nagle - socket.set_nodelay(true)?; + socket.set_tcp_nodelay(true)?; socket.set_recv_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; socket.set_send_buffer_size(DEFAULT_SOCKET_BUFFER_BYTES)?; diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index b0d82b1..f2849e3 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -4,7 +4,7 @@ #![allow(deprecated)] -use rand::Rng; +use rand::RngExt; use std::collections::{BTreeSet, HashMap}; use std::net::{IpAddr, SocketAddr}; use std::pin::Pin; @@ -600,7 +600,7 @@ impl UpstreamManager { return self.connect_retry_backoff; } let jitter_cap_ms = (base_ms / 2).max(1); - let jitter_ms = rand::rng().gen_range(0..=jitter_cap_ms); + let jitter_ms = rand::rng().random_range(0..=jitter_cap_ms); Duration::from_millis(base_ms.saturating_add(jitter_ms)) } @@ -667,7 +667,7 @@ impl UpstreamManager { "No healthy upstreams available! Using random." 
); } - return Some(filtered_upstreams[rand::rng().gen_range(0..filtered_upstreams.len())]); + return Some(filtered_upstreams[rand::rng().random_range(0..filtered_upstreams.len())]); } if healthy.len() == 1 { @@ -690,10 +690,10 @@ impl UpstreamManager { let total: f64 = weights.iter().map(|(_, w)| w).sum(); if total <= 0.0 { - return Some(healthy[rand::rng().gen_range(0..healthy.len())]); + return Some(healthy[rand::rng().random_range(0..healthy.len())]); } - let mut choice: f64 = rand::rng().gen_range(0.0..total); + let mut choice: f64 = rand::rng().random_range(0.0..total); for &(idx, weight) in &weights { if choice < weight {